DingZhenDojoCat commited on
Commit
bb7f76d
·
verified ·
1 Parent(s): e1b7c71

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +3 -0
  2. README.md +187 -0
  3. create_data.py +370 -0
  4. get_parquet_data.py +48 -0
  5. images/curves.png +3 -0
  6. images/demo2.png +3 -0
  7. images/frames.png +0 -0
  8. images/performance.png +3 -0
  9. merge_data.py +71 -0
  10. move.sh +5 -0
  11. move_eval.sh +5 -0
  12. move_result.sh +1 -0
  13. previous_version/Video-R1-main-previous/src/distill_r1/README.md +54 -0
  14. previous_version/Video-R1-main-previous/src/distill_r1/filter_r1.py +153 -0
  15. previous_version/Video-R1-main-previous/src/distill_r1/prompt.py +39 -0
  16. previous_version/Video-R1-main-previous/src/r1-v/.gitignore +178 -0
  17. previous_version/Video-R1-main-previous/src/r1-v/LICENSE +201 -0
  18. previous_version/Video-R1-main-previous/src/r1-v/Makefile +20 -0
  19. previous_version/Video-R1-main-previous/src/r1-v/configs/ddp.yaml +16 -0
  20. previous_version/Video-R1-main-previous/src/r1-v/configs/qwen2vl_sft_config.yaml +37 -0
  21. previous_version/Video-R1-main-previous/src/r1-v/configs/zero2.yaml +21 -0
  22. previous_version/Video-R1-main-previous/src/r1-v/configs/zero3.yaml +22 -0
  23. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/create_vision_cot_data.py +153 -0
  24. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/lmms_eval_qwen2vl.sh +61 -0
  25. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/prepare_hf_data.py +166 -0
  26. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/train_aria_moe.sh +68 -0
  27. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/train_qwen2_vl.sh +61 -0
  28. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero2.json +41 -0
  29. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero3.json +41 -0
  30. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero3.yaml +22 -0
  31. previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero3_offload.json +48 -0
  32. previous_version/Video-R1-main-previous/src/r1-v/run_grpo.sh +29 -0
  33. previous_version/Video-R1-main-previous/src/r1-v/setup.cfg +41 -0
  34. previous_version/Video-R1-main-previous/src/r1-v/setup.py +132 -0
  35. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/__init__.py +0 -0
  36. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/evaluate.py +85 -0
  37. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/generate.py +156 -0
  38. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/grpo.py +229 -0
  39. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/sft.py +322 -0
  40. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/trainer/__init__.py +4 -0
  41. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/trainer/grpo_trainer.py +652 -0
  42. previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer.py +832 -0
  43. setup.sh +19 -0
  44. src/download.py +24 -0
  45. src/eval_bench.py +277 -0
  46. src/eval_bench.sh +36 -0
  47. src/eval_bench_4567.sh +32 -0
  48. src/generate_cot_vllm.py +266 -0
  49. src/inference_example.py +93 -0
  50. src/scripts/run_grpo_video.sh +44 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/curves.png filter=lfs diff=lfs merge=lfs -text
+ images/demo2.png filter=lfs diff=lfs merge=lfs -text
+ images/performance.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,187 @@
# Video-R1: Reinforcing Video Reasoning in MLLMs

[[📖 Paper](https://arxiv.org/pdf/2503.21776)] [[🤗 Video-R1-7B-model](https://huggingface.co/Video-R1/Video-R1-7B)] [[🤗 Video-R1-train-data](https://huggingface.co/datasets/Video-R1/Video-R1-data)]
[[🤖 Video-R1-7B-model](https://modelscope.cn/models/Video-R1/Video-R1-7B)] [[🤖 Video-R1-train-data](https://modelscope.cn/datasets/Video-R1/Video-R1-data)]

## 👀 About Video-R1

Inspired by DeepSeek-R1's success in eliciting reasoning abilities through rule-based RL, we introduce Video-R1 as **the first work to *systematically* explore the R1 paradigm for eliciting video reasoning** within MLLMs.

We introduce T-GRPO, an extension of GRPO that incorporates temporal modeling to **explicitly promote temporal reasoning**. In addition, we construct two datasets: **Video-R1-COT-165k** for the SFT cold start and **Video-R1-260k** for RL training, both comprising image and video data.

Our Video-R1-7B obtains strong performance on several video reasoning benchmarks. For example, Video-R1-7B attains 35.8% accuracy on the video spatial reasoning benchmark VSI-Bench, **surpassing the commercial proprietary model GPT-4o**.

Video-R1-7B **can be easily trained** using 4 H20 (96 GB) GPUs or 5 A100 (80 GB) GPUs.
## 🔥 News
- [2025/05/28] Our Video-R1-7B achieves **36.5%** accuracy on the new video reasoning benchmark [**Video-Holmes**](https://video-holmes.github.io/Page.github.io/), beating the commercial models **o4-mini (29.9%)** and **Gemini-2.0-Flash (30.6%)**.
- [2025/03/28] We release our paper, code, model weights, and two curated training datasets on Hugging Face🤗 and ModelScope🤖.
- [2025/02/23] We release the preliminary version of Video-R1; see `./previous_version` for this version.

## 📍 Features

+ Support Qwen2.5-VL
+ Support vLLM training and inference
+ Support image-video mixed training
+ Support multiple answer-output types (multiple choice, numerical, OCR, free-form, regression)
+ Provide the full pipeline (dataset, CoT annotation, SFT training, RL training, evaluation, etc.)
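The rule-based grading behind these answer types is type-dependent. The sketch below is only an illustration of how such grading can work, not the repo's actual reward code; the function name and the numerical tolerance are assumptions.

```python
def grade_answer(problem_type: str, prediction: str, ground_truth: str) -> float:
    """Return 1.0 for a correct answer, 0.0 otherwise (rule-based sketch)."""
    pred = prediction.strip()
    gt = ground_truth.strip()
    if problem_type == "multiple choice":
        # Compare the chosen option letter, e.g. "(A)" vs "A"
        return float(pred.strip("()").upper()[:1] == gt.strip("()").upper()[:1])
    if problem_type in ("numerical", "regression"):
        try:
            # Numeric comparison with an assumed tolerance
            return float(abs(float(pred) - float(gt)) < 1e-2)
        except ValueError:
            return 0.0
    # OCR / free-form: case-insensitive exact match as the simplest rule
    return float(pred.lower() == gt.lower())
```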
## 🔍 Dataset

To overcome the scarcity of high-quality video reasoning training data, we strategically introduce image-based reasoning data as part of the training data. We collect data from a variety of public datasets and carefully sample and balance the proportion of each subset.

<img src="./images/dataset.png" style="zoom:80%;" />

To facilitate an effective SFT cold start, we leverage Qwen2.5-VL-72B to generate CoT rationales for the samples in Video-R1-260k. After applying basic rule-based filtering to remove low-quality or inconsistent outputs, we obtain a high-quality CoT dataset, Video-R1-COT-165k.
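The filtering rules themselves are not listed here; a minimal sketch of what rule-based CoT filtering can look like is below. The `<answer>` tag format follows the convention used in this repo's `create_data.py`; the exact-match consistency check is an assumption.

```python
import re

def keep_cot_sample(cot: str, ground_truth: str) -> bool:
    """Rule-based filter sketch: keep a generated CoT only if it is
    well-formed and its final answer agrees with the ground truth."""
    match = re.search(r"<answer>([\s\S]*?)</answer>", cot)
    if not match:
        return False  # malformed output: no <answer> block
    answer = match.group(1).strip()
    # Drop inconsistent outputs (simple case-insensitive exact match)
    return answer.lower() == ground_truth.strip().lower()
```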
## 🏆 Performance

<img src="./images/performance.png" style="zoom:80%;" />

Video-R1 significantly outperforms previous models across most benchmarks. Notably, on VSI-Bench, which focuses on spatial reasoning in videos, Video-R1-7B achieves a new state-of-the-art accuracy of 35.8%, surpassing GPT-4o, a proprietary model, while using only 32 frames and 7B parameters.

This highlights the necessity of explicit reasoning capability for solving video tasks, and confirms the effectiveness of reinforcement learning for them.

<div align="center">
  <img src="./images/frames.png" alt="Accuracy with different numbers of inference frames" width="85%">
</div>

Besides, although the model is trained with only 16 frames, we find that evaluating on more frames (e.g., 64) generally leads to better performance, particularly on benchmarks with longer videos. These results indicate the importance of training models to reason over more frames.
## 🧠 Aha Moment in Video Reasoning

One of the most intriguing outcomes of reinforcement learning in Video-R1 is the emergence of self-reflective reasoning behaviors, commonly referred to as "aha moments". Some examples follow.

<img src="./images/demo1.png" style="zoom:80%;" />

<img src="./images/demo2.png" style="zoom:80%;" />

## 📈 RL Training Curves

The accuracy reward exhibits a generally upward trend, indicating that the model continuously improves its ability to produce correct answers under RL.

Interestingly, the response length first drops at the beginning of RL training and then gradually increases. We conjecture that the model initially discards its previous, potentially sub-optimal reasoning style, and then gradually converges to a better and more stable reasoning policy.

<img src="./images/curves.png" style="zoom:80%;" />
## 📐 Set up

```bash
git clone https://github.com/tulerfeng/Video-R1
cd Video-R1

# build environment
conda create -n video-r1 python=3.11
conda activate video-r1
bash setup.sh

# Qwen video extraction settings, e.g., max frames, resolutions
# use the [decord] extra to improve decoding speed
cd src/qwen-vl-utils
pip install -e .[decord]
cd ..

# download training dataset
git lfs install
git clone https://huggingface.co/datasets/Video-R1/Video-R1-data
```

Please put the downloaded dataset in `src/r1-v/Video-R1-data/`.

Then, unzip the data:

```bash
python ./src/unzip.py
```

The `Video-R1-260k.json` file is for RL training, while `Video-R1-COT-165k.json` is for the SFT cold start.

Qwen2.5-VL has been frequently updated in the Transformers library, which may cause version-related bugs or inconsistencies. Our code is compatible with the following version, which you can download [here](https://drive.google.com/file/d/1Kc81WZitEhUZYWXpL6y2GXuSXufLSYcF/view?usp=sharing).

Then install our provided version of transformers:

```bash
unzip transformers-main.zip
cd ./transformers-main
pip install .
```

For the vLLM library, please use version 0.7.2.

For the trl library, please use version 0.16.0.
## 🚀 Training

We first perform supervised fine-tuning on the Video-R1-COT-165k dataset for one epoch to obtain the Qwen2.5-VL-7B-SFT model. If you want to perform CoT annotation on your own data, please refer to `src/generate_cot_vllm.py`.

```bash
bash ./src/scripts/run_sft_video.sh
```

If you want to skip the SFT process, we also provide one of our SFT models at [🤗 Qwen2.5-VL-SFT](https://huggingface.co/Video-R1/Qwen2.5-VL-7B-COT-SFT).

This is followed by RL training on the Video-R1-260k dataset to produce the final Video-R1 model. Due to current computational resource limitations, we train the model for only 1.2k RL steps.

The script for training the obtained Qwen2.5-VL-7B-SFT model with T-GRPO or GRPO is as follows:

```bash
bash ./src/scripts/run_grpo_video.sh
```

You can also use the following script to enable vLLM acceleration for RL training:

```bash
bash ./src/scripts/run_grpo_vllm_qwen25vl.sh
```

For efficiency, we limit the maximum number of video frames to 16 during training. Each frame is processed at a maximum resolution of 128 × 28 × 28. You can set this in `src/qwen-vl-utils`.

Please keep `per_device_train_batch_size=1`, as in the previous work R1-V.
## 🔮 Inference & Evaluation

During inference, we increase the maximum frame resolution to 256 × 28 × 28 and the maximum number of frames to 16/32/64 to enhance performance. You can easily set this in `src/qwen-vl-utils`.

For all evaluations, we follow the decoding configuration used in the official Qwen2.5-VL demo, with top\_p = 0.001 and temperature = 0.01. Setting a large top\_p may produce garbled output during inference.
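Expressed as Transformers-style `generate` keyword arguments, this decoding configuration would look like the sketch below; the `max_new_tokens` cap is an assumption, not specified in this README.

```python
# Decoding configuration following the official Qwen2.5-VL demo,
# as used for all evaluations here.
GEN_KWARGS = dict(
    do_sample=True,
    top_p=0.001,          # near-greedy; a large top_p may produce garbled output
    temperature=0.01,
    max_new_tokens=1024,  # assumed cap, not specified in the README
)

# model.generate(**inputs, **GEN_KWARGS)  # hypothetical call site
```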
We recommend using our provided JSON files and scripts for easier evaluation.

The JSON files can be downloaded from [[🤗 Video-R1-eval](https://huggingface.co/datasets/Video-R1/Video-R1-eval)]; put them in `/src/r1-v/Evaluation`.

Next, download the evaluation video data from each benchmark's official website, and place it in `/src/r1-v/Evaluation` as specified in the provided JSON files.

Finally, conduct evaluation on all benchmarks using the following script:

```bash
bash ./src/eval_bench.sh
```

For inference on a single example, you may use:

```bash
python ./src/inference_example.py
```
## Acknowledgements

We sincerely appreciate the contributions of the open-source community. Related projects: [R1-V](https://github.com/Deep-Agent/R1-V), [DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1)

## Citations

If you find our work helpful for your research, please consider citing it.

```bibtex
@article{feng2025video,
  title={Video-R1: Reinforcing Video Reasoning in MLLMs},
  author={Feng, Kaituo and Gong, Kaixiong and Li, Bohao and Guo, Zonghao and Wang, Yibing and Peng, Tianshuo and Wang, Benyou and Yue, Xiangyu},
  journal={arXiv preprint arXiv:2503.21776},
  year={2025}
}
```
create_data.py ADDED
@@ -0,0 +1,370 @@
# import re
# from pathlib import Path
# from datasets import load_dataset, Dataset, DatasetDict, Features, Value, Image
# import re
# from typing import Dict, List, Optional
# from pathlib import Path
# from datasets import Dataset, DatasetDict, concatenate_datasets, Features, Value, Sequence


# # ------------------------------------------------------------------
# # 0) Load your JSON → `raw_ds` exactly as before
# # ------------------------------------------------------------------

# files = [
#     "pool_multiple_choice_chunk_01.json",
#     "pool_multiple_choice_chunk_02.json",
#     "pool_multiple_choice_chunk_03.json",
#     "pool_multiple_choice_chunk_04.json",
#     "pool_numerical_chunk_01.json",
#     "pool_numerical_chunk_02.json",
#     "pool_numerical_chunk_03.json",
#     "pool_regression_chunk_01.json",
# ]

# # ---- 1-4. load, trim, normalise ----------------------------------------
# def load_trim_normalise(fp, cap=10_000):
#     ds = Dataset.from_json(fp)

#     # a) truncate
#     ds = ds.select(range(min(cap, len(ds))))

#     # b) make sure `options` exists and is always list[str]
#     if "options" not in ds.column_names:
#         ds = ds.add_column("options", [[]] * len(ds))
#     else:
#         ds = ds.map(
#             lambda ex: {"options": [str(o) for o in (ex["options"] or [])]},
#             remove_columns=[], num_proc=4,
#         )

#     return ds

# ds_list = [load_trim_normalise(fp) for fp in files]

# # ---- 4. align feature schema explicitly (all files now identical) -------
# common_features = Features({
#     "problem_id"  : Value("int64"),
#     "problem"     : Value("string"),
#     "data_type"   : Value("string"),
#     "problem_type": Value("string"),
#     "options"     : Sequence(Value("string")),
#     "solution"    : Value("string"),
#     "path"        : Value("string"),
#     "data_source" : Value("string"),
# })
# ds_list = [d.cast(common_features) for d in ds_list]

# # ---- 5. concatenate -----------------------------------------------------
# raw_train = concatenate_datasets(ds_list)
# raw_ds = DatasetDict({"train": raw_train})

# # ------------------------------------------------------------------
# # 1) Build the question (unchanged)
# # ------------------------------------------------------------------
# def build_question(example):
#     q = (
#         example["problem"] + " Options:\n" + "\n".join(example["options"])
#         if example["problem_type"] == "multiple choice"
#         else example["problem"]
#     )
#     example["problem"] = q
#     return example


# def extract_answer(predict: str) -> Optional[str]:
#     """
#     Extracts the content of the <answer>…</answer> block from `predict`.
#     Returns the inner text (with leading/trailing whitespace stripped),
#     or None if no <answer> tag is found.
#     """
#     match = re.search(r"<answer>([\s\S]*?)</answer>", predict, re.DOTALL)
#     if not match:
#         return predict
#     return match.group(1).strip()


# def add_answer(example):
#     # assumes the ground-truth answer (tagged) is in `solution`
#     example["answer"] = extract_answer(example["solution"])
#     return example

# # ------------------------------------------------------------------
# # 3) Embed image bytes (column name stays "images")
# # ------------------------------------------------------------------
# def to_embedded_image(example):
#     if example["data_type"] != "image":
#         example["images"] = None
#         return example
#     with open(example["path"], "rb") as f:
#         img_bytes = f.read()
#     example["images"] = {"bytes": img_bytes, "path": None}
#     return example

# # ------------------------------------------------------------------
# # 4) Full pipeline
# # ------------------------------------------------------------------
# processed = (
#     raw_ds["train"]
#     .map(build_question, num_proc=4)
#     .map(add_answer, num_proc=4)
#     .map(to_embedded_image, num_proc=4)
#     .remove_columns([
#         "path", "data_type", "options", "problem_type", "solution",
#         "problem_id", "data_source"  # ← drop these too
#     ])
# )

# # ------------------------------------------------------------------
# # 5) Schema must match the final column names
# # ------------------------------------------------------------------
# features = Features({
#     "problem": Value("string"),
#     "answer" : Value("string"),
#     "images" : Image(),  # keep plural name
# })
# processed = processed.cast(features)

# # ------------------------------------------------------------------
# # 6) Write Parquet shards (file prefix inside the folder)
# # ------------------------------------------------------------------
# out_dir = Path("qwen2.5_vl_portable")
# out_dir.mkdir(parents=True, exist_ok=True)

# # processed.to_parquet(str(out_dir / "train.parquet"))  # → train-00000-of-00001.parquet
# processed.to_parquet(str("./hf_data/train.parquet"))
# print("✓ Dataset written with embedded images and answers →", out_dir.resolve())


# import re
# from pathlib import Path
# from typing import Dict, List, Optional

# from datasets import (
#     Dataset,
#     DatasetDict,
#     concatenate_datasets,
#     Features,
#     Value,
#     Sequence,
#     Image,
# )

# # ------------------------------------------------------------------
# # 0) Inputs
# # ------------------------------------------------------------------
# files = [
#     "pool_multiple_choice_chunk_01.json",
#     "pool_multiple_choice_chunk_02.json",
#     "pool_multiple_choice_chunk_03.json",
#     "pool_multiple_choice_chunk_04.json",
#     "pool_numerical_chunk_01.json",
#     "pool_numerical_chunk_02.json",
#     "pool_numerical_chunk_03.json",
#     "pool_regression_chunk_01.json",
# ]

# # ------------------------------------------------------------------
# # 1) Define common meta schema (what you want to keep in the output)
# # ------------------------------------------------------------------
# common_features = Features({
#     "problem_id"  : Value("int64"),
#     "problem"     : Value("string"),
#     "data_type"   : Value("string"),
#     "problem_type": Value("string"),
#     "options"     : Sequence(Value("string")),
#     "solution"    : Value("string"),
#     "path"        : Value("string"),
#     "data_source" : Value("string"),
# })

# # Final (superset) schema to write: meta + new columns
# full_features = common_features.copy()
# full_features["answer"] = Value("string")
# full_features["images"] = Image()  # plural name kept, binary-friendly


# # ------------------------------------------------------------------
# # 2) Load + normalize each JSON
# # ------------------------------------------------------------------
# def load_trim_normalise(fp: str, cap: int = 10_000) -> Dataset:
#     ds = Dataset.from_json(fp)

#     # truncate if desired
#     ds = ds.select(range(min(cap, len(ds))))

#     # ensure `options` exists and is always list[str]
#     if "options" not in ds.column_names:
#         ds = ds.add_column("options", [[]] * len(ds))
#     else:
#         ds = ds.map(
#             lambda ex: {"options": [str(o) for o in (ex["options"] or [])]},
#             remove_columns=[],
#             num_proc=4,
#         )

#     # align to the common meta schema early (helps concat)
#     # Some JSONs may not have all fields; add missing with defaults first.
#     missing_cols = [k for k in common_features.keys() if k not in ds.column_names]
#     for mc in missing_cols:
#         # create sensible defaults
#         if mc == "options":
#             ds = ds.add_column(mc, [[]] * len(ds))
#         elif common_features[mc].dtype == "int64":
#             ds = ds.add_column(mc, [0] * len(ds))
#         else:
#             ds = ds.add_column(mc, [""] * len(ds))

#     ds = ds.cast(common_features)
#     return ds

# ds_list = [load_trim_normalise(fp) for fp in files]

# # Concatenate shards
# raw_train = concatenate_datasets(ds_list)
# raw_ds = DatasetDict({"train": raw_train})


# # ------------------------------------------------------------------
# # 3) Processing fns
# # ------------------------------------------------------------------
# def build_question(example: Dict) -> Dict:
#     """
#     If multiple-choice, append the options to the text.
#     Overwrites the `problem` field in-place (kept in output).
#     """
#     if example["problem_type"] == "multiple choice":
#         opts = example.get("options") or []
#         q = example["problem"] + " Options:\n" + "\n".join(opts)
#         example["problem"] = q
#     return example


# def extract_answer(predict: str) -> Optional[str]:
#     """
#     Return inner text of <answer>...</answer>, stripped.
#     If no tag is found, return the original string.
#     """
#     if predict is None:
#         return None
#     match = re.search(r"<answer>([\s\S]*?)</answer>", predict, re.DOTALL)
#     if not match:
#         return predict
#     return match.group(1).strip()


# def add_answer(example: Dict) -> Dict:
#     example["answer"] = extract_answer(example.get("solution", ""))
#     return example


# def to_embedded_image(example: Dict) -> Dict:
#     """
#     If data_type == 'image', embed bytes for HF Image() feature.
#     Otherwise leave as None.
#     """
#     if example.get("data_type") != "image":
#         example["images"] = None
#         return example

#     path = example.get("path")
#     if not path:
#         example["images"] = None
#         return example

#     try:
#         with open(path, "rb") as f:
#             img_bytes = f.read()
#         example["images"] = {"bytes": img_bytes, "path": None}
#     except Exception:
#         # If image is missing or unreadable, keep None so cast still works
#         example["images"] = None
#     return example


# # ------------------------------------------------------------------
# # 4) Apply pipeline (do NOT drop meta columns you want to keep)
# # ------------------------------------------------------------------
# processed = (
#     raw_ds["train"]
#     .map(build_question, num_proc=4)
#     .map(add_answer, num_proc=4)
#     .map(to_embedded_image, num_proc=4)
#     .cast(full_features)  # <- ensure final schema
# )

# # Optional: control output column ordering
# processed = processed.select_columns(list(full_features.keys()))

# # ------------------------------------------------------------------
# # 5) Write Parquet
# # ------------------------------------------------------------------
# out_dir = Path("./hf_data")
# out_dir.mkdir(parents=True, exist_ok=True)

# out_path = out_dir / "train.parquet"
# processed.to_parquet(str(out_path))

# print("✓ Wrote:", out_path.resolve())
# print("Columns:", list(processed.features.keys()))

# ------------------------------------------------------------------
# 4.1) Downsample to 30k, mainly reducing math-heavy sources
# ------------------------------------------------------------------
from collections import Counter
from pathlib import Path
from typing import Optional

from datasets import concatenate_datasets

# NOTE: `processed` is produced by the pipeline above (currently commented
# out); re-enable it before running this section.

TARGET_SIZE = 30_000
MATH_SHARE = 0.20  # keep ~20% math (tweak if you want)
SEED = 2025

# Define which sources are "mathy"
MATH_SOURCES = {
    "Multimath-300k",
    "TabMWP",
    "Geometry3K",
    "CLEVR-Math",
    "DVQA",
    "FigureQA",
    "ChartQA",
    "PlotQA",
    "EXAMS-V-train/Mathematics",
    "UniGeo",
    "GeoQA+",
}

def is_math_source(name: Optional[str]) -> bool:
    if not name:
        return False
    return name in MATH_SOURCES or ("math" in name.lower())

# Split
math_ds = processed.filter(lambda ex: is_math_source(ex.get("data_source")), num_proc=4)
non_math_ds = processed.filter(lambda ex: not is_math_source(ex.get("data_source")), num_proc=4)

# Decide quotas
non_math_quota = min(len(non_math_ds), int(TARGET_SIZE * (1 - MATH_SHARE)))
math_quota = TARGET_SIZE - non_math_quota
math_quota = min(math_quota, len(math_ds))  # guard if math is too small

# Sample deterministically
non_math_sample = non_math_ds.shuffle(seed=SEED).select(range(non_math_quota))
math_sample = math_ds.shuffle(seed=SEED).select(range(math_quota))

# Combine and shuffle
final = concatenate_datasets([non_math_sample, math_sample]).shuffle(seed=SEED)

# Quick sanity printout
cnt = Counter(final["data_source"])
total = len(final)
print(f"Final size: {total} (non-math {non_math_quota}, math {math_quota})")
for name, n in sorted(cnt.items(), key=lambda x: -x[1])[:25]:
    pct = n / total
    print(f"{name:30s} {n:6d} {pct:7.3%}")

# Use this 'final' dataset for writing
processed = final
out_dir = Path("./hf_data")
out_dir.mkdir(parents=True, exist_ok=True)
out_path = out_dir / "train_30k.parquet"
processed.to_parquet(str(out_path))
print("✓ Wrote:", out_path.resolve())
get_parquet_data.py ADDED
@@ -0,0 +1,48 @@
import json
from datasets import Dataset, Features, Sequence, Value, Image

# 1️⃣ Load your JSON file (which is a top-level list of dicts)
with open("Train_QA_10k_noFreeForm.json", "r") as f:
    records = json.load(f)  # List[Dict]

# 2️⃣ Build an HF Dataset
ds = Dataset.from_list(records)

# 3️⃣ Read each image file into raw bytes
def read_image_bytes(example):
    with open(example["path"], "rb") as img_f:
        example["image_bytes"] = img_f.read()
    return example

# we keep all original columns + add "image_bytes"
ds = ds.map(read_image_bytes, remove_columns=[])

# 4️⃣ Define your schema, telling HF that image_bytes is binary
features = Features({
    "problem_id": Value("int64"),
    "problem": Value("string"),
    "data_type": Value("string"),
    "problem_type": Value("string"),
    "options": Sequence(Value("string")),
    "solution": Value("string"),
    "data_source": Value("string"),
    # "prompt": Value("string"),
    "answer": Value("string"),
    "path": Value("string"),
    "image_bytes": Value("binary"),  # ← raw bytes in Arrow
})
ds = ds.cast(features)

# 5️⃣ Rename, and cast that byte-column to an Image feature that decodes to PIL
ds = ds.rename_column("image_bytes", "images")
ds = ds.cast_column("images", Image(decode=True))

# 6️⃣ Sanity-check
img0 = ds[0]["images"]
print(img0)
# → PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x384

# 7️⃣ Finally, write out to Parquet (the bytes go in the file)
ds.to_parquet("./hf_data/Train_QA_10k_noFreeForm.parquet")
images/curves.png ADDED

Git LFS Details

  • SHA256: 72f6c9fdc2b4e59df266b31e85cf186ee1acc9b508b9001f709f05873dd00b20
  • Pointer size: 131 Bytes
  • Size of remote file: 277 kB
images/demo2.png ADDED

Git LFS Details

  • SHA256: 8a6dde7c88f100fdffcc01142dd997d1f53f6255d0d4d5ebe36808d8733d6280
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
images/frames.png ADDED
images/performance.png ADDED

Git LFS Details

  • SHA256: bf4f04aa5954b18a90471b992cf50d45ba1289bf1412ff77e6d5557b88c09659
  • Pointer size: 131 Bytes
  • Size of remote file: 445 kB
merge_data.py ADDED
@@ -0,0 +1,71 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Iterator, Dict
4
+
5
+ # -----------------------------
6
+ # Inputs
7
+ # -----------------------------
8
+ files = [
9
+ "pool_multiple_choice_chunk_01.json",
10
+ "pool_multiple_choice_chunk_02.json",
11
+ "pool_multiple_choice_chunk_03.json",
12
+ "pool_multiple_choice_chunk_04.json",
13
+ "pool_numerical_chunk_01.json",
14
+ "pool_numerical_chunk_02.json",
15
+ "pool_numerical_chunk_03.json",
16
+ "pool_regression_chunk_01.json",
17
+ ]
18
+
19
+ out_path = Path("merged_train.json")
20
+
21
+ # -----------------------------
22
+ # Read records from JSON/JSONL
23
+ # -----------------------------
24
+ def iter_records(path: Path) -> Iterator[Dict]:
25
+ """
26
+ Yields records from a file that can be:
27
+ - JSONL (one JSON object per line), or
28
+ - a single JSON array, or
29
+ - a single JSON object.
30
+ """
31
+ text = path.read_text(encoding="utf-8")
32
+ # Try whole-file JSON first (array or object)
33
+ try:
34
+ data = json.loads(text)
35
+ if isinstance(data, list):
36
+ for rec in data:
37
+ yield rec
38
+ elif isinstance(data, dict):
39
+ yield data
40
+ else:
41
+ raise ValueError(f"Unsupported top-level JSON type in {path}")
42
+ except json.JSONDecodeError:
43
+ # Fallback: treat as JSONL
44
+ for i, line in enumerate(text.splitlines(), 1):
45
+ line = line.strip()
46
+ if not line:
47
+ continue
48
+ try:
49
+ yield json.loads(line)
50
+ except json.JSONDecodeError as e:
51
+ raise ValueError(f"Invalid JSON on line {i} in {path}: {e}") from e
52
+
53
+ # -----------------------------
54
+ # Merge & write single JSON file
55
+ # -----------------------------
56
+ out_path.parent.mkdir(parents=True, exist_ok=True)
57
+
58
+ count = 0
59
+ with out_path.open("w", encoding="utf-8") as out:
60
+ out.write("[\n")
61
+ first = True
62
+ for fp in files:
63
+ for rec in iter_records(Path(fp)):
64
+ if not first:
65
+ out.write(",\n")
66
+ out.write(json.dumps(rec, ensure_ascii=False))
67
+ first = False
68
+ count += 1
69
+ out.write("\n]")
70
+
71
+ print(f"✓ Wrote {count} records to {out_path.resolve()}")
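After merging, a quick sanity check can confirm the output is one well-formed JSON array. This is a sketch, not part of the repository; `merged_train.json` is the file written by the script above:

```python
import json
from pathlib import Path

def count_merged(path: str) -> int:
    """Parse the merged file and return the number of records it holds."""
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    # The merge script writes a single top-level JSON array.
    assert isinstance(data, list), "merged output should be a JSON array"
    return len(data)
```

Running `count_merged("merged_train.json")` after the merge should report the same count that the script prints.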
move.sh ADDED
@@ -0,0 +1,5 @@
+ cp -r /cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/scripts/ ./src/
+
+ cp -r /cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/ ./src/r1-v/src/
+
+ cp -r /cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/local_scripts/ ./src/r1-v/
move_eval.sh ADDED
@@ -0,0 +1,5 @@
+ cp /cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/eval_bench.py ./src/
+
+ cp /cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/eval_bench.sh ./src/
+
+ cp /cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/eval_bench_4567.sh ./src/
move_result.sh ADDED
@@ -0,0 +1 @@
+ cp -r ./src/r1-v/eval_results/* /cq_1/share_1603164/user/zongxia/workspace/A-EVALUTION/video_eval_results/
previous_version/Video-R1-main-previous/src/distill_r1/README.md ADDED
@@ -0,0 +1,54 @@
+ # R1 Reasoning Dataset Generation
+
+ ## QA Pairs Generation
+
+ We create a `scene description` by combining the objects (with meta info such as location and depth) using a template.
+
+ We keep the counting-related questions and add a `How many items are there in the described scene?` question to count all objects in the scene.
+
+ Example QA pair:
+
+ ```json
+ {'img_filename': 'CLEVR_trainA_048403.png',
+ 'question': 'How many things are both on the right side of the big yellow rubber thing and left of the purple ball?',
+ 'answer': '5',
+ 'description': 'Scene Description:\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\n'}
+ ```
+
+ See `generate_scene_qa_pairs.ipynb` for details.
+
+ ## R1 Response Generation
+
+ `query_r1.py`: We get the predictions and reasoning traces of R1 by querying SiliconFlow (you can register via https://cloud.siliconflow.cn/i/rYfeQZ6A so we both get 20M token credits!)
+
+ Example Response:
+
+ ```json
+ <think>Okay, let's see. The user is asking how many items are there in the described scene. Let me go through the scene description step by step.
+
+ So, the scene description lists each object with details like color, material, shape, rotation, 3D coordinates, and pixel coordinates. Each entry starts with "A" which usually indicates one item each. Let me count each one.
+
+ First entry: "A small green metal cylinder..." That's one. Second: "A small blue rubber cylinder..." Second item. Third: "A small cyan rubber cylinder..." That's three. Fourth: "A large cyan metal sphere..." Four. Fifth: "A large brown metal cube..." Five. Sixth: "A large yellow rubber cube..." Six. Seventh: "A large brown rubber cylinder..." That's seven.
+
+ Wait, did I miss any? Let me check again. The list has entries from "A small green..." up to the seventh one. Each sentence starts with "A", which suggests each is a separate item. No commas separating multiple items in a single entry. Each has different attributes and coordinates, so they must all be distinct.
+
+ So the answer should be 7 items.
+ </think>
+
+ There are 7 items in the described scene. Each entry corresponds to one distinct object, listed by their properties, coordinates, and rotations.
+ ```
+
+ In `v1` we did not constrain the output format, so the answers were somewhat messy to parse. We then switched to `v2`, explicitly prompting the model to generate the answer with `**The answer is: **`.
+
+ ## Reasoning Path Filtering
+
+ `filter_r1.py`: We keep only the (almost certainly) valid reasoning traces by judging whether the R1 answer matches the ground truth (following our previous work [Math-Shepherd](https://arxiv.org/abs/2312.08935)).
+
+ ## HF dataset creation
+
+ Finally, we create the dataset using `create_hf_dataset.py` and upload it to the HF dataset hub.
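The `**The answer is: **` convention described in the README can be parsed with a minimal extractor. This is a sketch for illustration only; the repository's full pattern set, which covers many more answer formats, lives in `filter_r1.py`:

```python
import re
from typing import Optional

def extract_v2_answer(response: str) -> Optional[str]:
    """Extract the answer following the v2 marker '**The answer is: **'.

    Returns the lowercased answer token (a number, 'yes', or 'no'),
    or None when the marker is absent.
    """
    m = re.search(
        r"\*\*The answer is:\s*\*\*\s*([0-9]+|yes|no)\b",
        response,
        re.IGNORECASE,
    )
    return m.group(1).lower() if m else None
```

For example, `extract_v2_answer("**The answer is: ** 7")` yields `"7"`, while a response without the marker yields `None`.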
previous_version/Video-R1-main-previous/src/distill_r1/filter_r1.py ADDED
@@ -0,0 +1,153 @@
+ import json
+ import re
+ from pathlib import Path
+
+
+ def extract_answer_from_query(query_results: str) -> str | None:
+     """
+     Extract answer from query results, specifically looking for:
+     - Numbers within asterisks
+     - Yes/No answers in various formats
+
+     Args:
+         query_results: String containing the query response
+
+     Returns:
+         Extracted answer string or None if no answer found
+     """
+     # First try to find answers in the standard format with labels
+     # Split the text into segments (trying to get the last conclusion)
+     if "<think>" not in query_results or "</think>" not in query_results:
+         return None
+     segments = query_results.split("\n")
+
+     # First try to find a final conclusion in the last few segments
+     conclusion_patterns = [
+         r"(?:so|therefore|thus|hence),?\s*(?:the answer is\s+)?\*\*\s*(no|yes|[0-9]+)\s*\*\*",
+         r"(?:so|therefore|thus|hence),?\s*(?:the answer is\s+)?(no|yes|[0-9]+)\b",
+         r"the answer is\s+\*\*\s*(no|yes|[0-9]+)\s*\*\*",
+         r"(?:final|conclusive) answer(?:\s+is)?\s*\*\*\s*(no|yes|[0-9]+)\s*\*\*",
+     ]
+
+     # Try to find a conclusion in the last 3 segments
+     for segment in reversed(segments[-3:]):
+         for pattern in conclusion_patterns:
+             match = re.search(pattern, segment, re.IGNORECASE)
+             if match:
+                 return match.group(1).strip().lower()
+
+     # If no conclusion found, try other patterns on the full text
+     labeled_patterns = [
+         r"\*\*The answer is:\s*\*\*\s*([0-9]+|yes|no)\b",
+         r"\*\*Answer:\s*\*\*\s*([0-9]+|yes|no)\b",
+         r"\*\*Answer\*\*:\s*([0-9]+|yes|no)\b",
+         r"\*\*Answer:?\s*\*\*\s*There (?:is|are)\s+([0-9]+)",
+         r"\*\*Final Count:\s*\*\*\s*([0-9]+)",
+         r"\*\*Final Count:\s*\*\*\s*([0-9]+)\s+(?:items?|objects?|spheres?|cubes?|boxes?)",
+         r"\*\*Total:\s*\*\*\s*([0-9]+)",
+         r"The answer is:\s*([0-9]+|yes|no)\b",
+         r"Answer:\s*([0-9]+|yes|no)\b",
+         r"should be\s+([0-9]+)[.\s]",
+     ]
+
+     direct_patterns = [
+         r"\*\*\s*([0-9]+)\s*\*\*",
+         r"\*\*\s*([0-9]+)\s+(?:items?|objects?|cubes?|boxes?|spheres?)?\s*\*\*",
+         r"\*\*\s*([0-9]+)\s+[^*]+\*\*",
+     ]
+
+     latex_patterns = [
+         r"\$\\boxed{([0-9]+)}\$",
+         r"\\boxed{([0-9]+)}",
+     ]
+
+     count_patterns = [
+         r"There (?:is|are)\s+([0-9]+)\s+(?:items?|objects?|spheres?|cubes?|boxes?)",
+     ]
+
+     # Try all patterns in sequence on full text
+     all_patterns = labeled_patterns + direct_patterns + latex_patterns + count_patterns
+
+     for pattern in all_patterns:
+         match = re.search(pattern, query_results, re.IGNORECASE)
+         if match:
+             return match.group(1).strip().lower()
+
+     return None
+
+
+ def validate_qa_pairs(input_file: str, output_dir: str, verbose: bool = True):
+     """
+     Process QA pairs and save them to separate files.
+     Only saves pairs where parsed answer matches ground truth.
+
+     Args:
+         input_file: Path to input JSONL file
+         output_dir: Directory to save output files
+         verbose: If True, print examples of mismatched or unparseable responses
+     """
+     output_dir = Path(output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     valid_pairs = []
+     invalid_pairs = []
+     stats = {"total": 0, "unparseable": 0, "mismatch": 0, "valid": 0}
+
+     with open(input_file, "r", encoding="utf-8") as f:
+         for line_num, line in enumerate(f, 1):
+             stats["total"] += 1
+             qa_pair = json.loads(line.strip())
+             ground_truth = str(qa_pair.get("a", "")).lower().strip()
+             parsed_answer = extract_answer_from_query(qa_pair["r1_response"])
+
+             if parsed_answer is None:
+                 stats["unparseable"] += 1
+                 qa_pair["error"] = "unparseable"
+                 invalid_pairs.append(qa_pair)
+                 if verbose:
+                     print(f"\nLine {line_num}: Could not parse answer")
+                     print(f"Ground truth: {ground_truth}")
+                     print(f"Query results: {qa_pair['r1_response'][-200:]}...")
+             elif parsed_answer != ground_truth:
+                 stats["mismatch"] += 1
+                 qa_pair["error"] = "mismatch"
+                 qa_pair["parsed_answer"] = parsed_answer
+                 invalid_pairs.append(qa_pair)
+                 if verbose:
+                     print(f"\nLine {line_num}: Answer mismatch")
+                     print(f"Ground truth: {ground_truth}")
+                     print(f"Parsed answer: {parsed_answer}")
+                     print(f"Query results: {qa_pair['r1_response'][-200:]}...")
+             else:
+                 stats["valid"] += 1
+                 valid_pairs.append(qa_pair)
+
+     # Save valid pairs (where parsed answer matches ground truth)
+     valid_file = output_dir / "valid_pairs.jsonl"
+     with open(valid_file, "w", encoding="utf-8") as f:
+         for pair in valid_pairs:
+             f.write(json.dumps(pair, ensure_ascii=False) + "\n")
+
+     # Save invalid pairs (unparseable or mismatched)
+     invalid_file = output_dir / "invalid_pairs.jsonl"
+     with open(invalid_file, "w", encoding="utf-8") as f:
+         for pair in invalid_pairs:
+             f.write(json.dumps(pair, ensure_ascii=False) + "\n")
+
+     # Print statistics
+     print(f"\nProcessing Summary:")
+     print(f"Total pairs processed: {stats['total']}")
+     print(f"Valid pairs (matching ground truth): {stats['valid']}")
+     print(f"Invalid pairs: {stats['unparseable'] + stats['mismatch']}")
+     print(f"  - Unparseable: {stats['unparseable']}")
+     print(f"  - Answer mismatch: {stats['mismatch']}")
+     print(f"\nOutput files:")
+     print(f"Valid pairs saved to: {valid_file}")
+     print(f"Invalid pairs saved to: {invalid_file}")
+
+
+ if __name__ == "__main__":
+     validate_qa_pairs(
+         "r1_results_clevr_cogent_v1.0_trainA_v2.jsonl", "filter_results_v2"
+     )  # "filtered_output_tmp_v1.jsonl")
previous_version/Video-R1-main-previous/src/distill_r1/prompt.py ADDED
@@ -0,0 +1,39 @@
+ R1_SYS_PROMPT = """You are DeepSeek-R1, an AI assistant created exclusively by the Chinese Company DeepSeek. You'll provide helpful, harmless, and detailed responses to all user inquiries. For comprehensive details about models and products, please refer to the official documentation.
+
+ Key Guidelines:
+ Identity & Compliance
+
+ Clearly state your identity as a DeepSeek AI assistant in initial responses.
+
+ Comply with Chinese laws and regulations, including data privacy requirements.
+
+ Capability Scope
+
+ Handle both Chinese and English queries effectively
+
+ Acknowledge limitations for real-time information post knowledge cutoff (2023-12)
+
+ Provide technical explanations for AI-related questions when appropriate
+
+ Response Quality
+
+ Give comprehensive, logically structured answers
+
+ Use markdown formatting for clear information organization
+
+ Admit uncertainties for ambiguous queries
+
+ Ethical Operation
+
+ Strictly refuse requests involving illegal activities, violence, or explicit content
+
+ Maintain political neutrality according to company guidelines
+
+ Protect user privacy and avoid data collection
+
+ Specialized Processing
+
+ Use <think>...</think> tags for internal reasoning before responding
+
+ Employ XML-like tags for structured output when required
+ """
previous_version/Video-R1-main-previous/src/r1-v/.gitignore ADDED
@@ -0,0 +1,178 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ #  Usually these files are written by a python script from a template
+ #  before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ #   For a library or package, you might want to ignore these files since the code is
+ #   intended to run in multiple environments; otherwise, check them in:
+ #   .python-version
+
+ # pipenv
+ #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ #   However, in case of collaboration, if having platform-specific dependencies or dependencies
+ #   having no cross-platform support, pipenv may install dependencies that don't work, or not
+ #   install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ #   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ #   This is especially recommended for binary packages to ensure reproducibility, and is more
+ #   commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ #   This is especially recommended for binary packages to ensure reproducibility, and is more
+ #   commonly ignored for libraries.
+ #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ #   in version control.
+ #   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ #  and can be added to the global gitignore or merged into this file. For a more nuclear
+ #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Temp folders
+ data/
+ wandb/
+ scripts/
+ checkpoints/
+ .vscode/
previous_version/Video-R1-main-previous/src/r1-v/LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: style quality
2
+
3
+ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4
+ export PYTHONPATH = src
5
+
6
+ check_dirs := src
7
+
8
+ style:
9
+ black --line-length 119 --target-version py310 $(check_dirs) setup.py
10
+ isort $(check_dirs) setup.py
11
+
12
+ quality:
13
+ black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14
+ isort --check-only $(check_dirs) setup.py
15
+ flake8 --max-line-length 119 $(check_dirs) setup.py
16
+
17
+
18
+ # Evaluation
19
+
20
+ evaluate:
previous_version/Video-R1-main-previous/src/r1-v/configs/ddp.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ gpu_ids: all
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: bf16
9
+ num_machines: 1
10
+ num_processes: 8
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
previous_version/Video-R1-main-previous/src/r1-v/configs/qwen2vl_sft_config.yaml ADDED
@@ -0,0 +1,37 @@
+ # Model arguments
+ model_name_or_path: Qwen/Qwen2-VL-2B-Instruct
+ model_revision: main
+ torch_dtype: bfloat16
+
+ # Data training arguments
+ dataset_name: /GEOQA_R1V_Train_8K
+ dataset_configs:
+ - all
+ preprocessing_num_workers: 4
+
+ # SFT trainer config
+ bf16: true
+ do_eval: true
+ eval_strategy: "no"
+ gradient_accumulation_steps: 4
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+   use_reentrant: false
+ learning_rate: 2.0e-05
+ log_level: info
+ logging_steps: 5
+ logging_strategy: steps
+ lr_scheduler_type: cosine
+ packing: true
+ max_seq_length: 4096
+ max_steps: -1
+ num_train_epochs: 1
+ output_dir: ./log/Qwen2-VL-2B-Instruct-SFT
+ overwrite_output_dir: true
+ per_device_eval_batch_size: 1
+ per_device_train_batch_size: 1
+ report_to:
+ - wandb
+ save_strategy: "no"
+ seed: 42
+ warmup_ratio: 0.1
previous_version/Video-R1-main-previous/src/r1-v/configs/zero2.yaml ADDED
@@ -0,0 +1,21 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 4
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
previous_version/Video-R1-main-previous/src/r1-v/configs/zero3.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: true
+   zero3_save_16bit_model: true
+   zero_stage: 3
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/create_vision_cot_data.py ADDED
@@ -0,0 +1,153 @@
+ import argparse
+ import base64
+ import concurrent.futures
+ import io
+ import json
+ import os
+ import random
+ import re
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+ from io import BytesIO
+ from typing import Dict, List
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
+ from tqdm import tqdm
+
+ import bytedtos
+ import seaborn as sns
+ import yaml
+ from openai import AzureOpenAI
+ from PIL import Image
+ from pillow_avif import AvifImagePlugin
+
+
+ PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
+
+ Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
+
+ Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
+
+ Input Format:
+ Original Question: {original_question}
+ Original Answer: {original_answer}
+
+ Output Format:
+ Question: [rewrite the question if necessary]
+ Answer: [answer with reasoning steps, including calculations where applicable]
+ <think>step-by-step reasoning process</think>
+ <answer>easy to verify answer</answer>
+ """
+
+
+ def get_image_data_url(image_input):
+     if isinstance(image_input, str) and image_input.startswith("data:"):
+         return image_input
+
+     if isinstance(image_input, str) and image_input.startswith("http"):
+         image_input = load_image(image_input)
+
+     if isinstance(image_input, str):
+         image_input = Image.open(image_input)
+
+     if not isinstance(image_input, Image.Image):
+         raise ValueError("Unsupported image input type")
+
+     if image_input.mode != "RGB":
+         image_input = image_input.convert("RGB")
+
+     buffer = BytesIO()
+     image_input.save(buffer, format="JPEG")
+     img_bytes = buffer.getvalue()
+     base64_data = base64.b64encode(img_bytes).decode("utf-8")
+     return f"data:image/jpeg;base64,{base64_data}"
+
+
+ def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
+     if image is None:
+         return None
+
+     data_url_list = [get_image_data_url(image)]
+     client = AzureOpenAI(
+         azure_endpoint="YOUR_AZURE_ENDPOINT",
+         api_version="2023-07-01-preview",
+         api_key="YOUR_API_KEY",
+     )
+
+     for attempt in range(max_retries):
+         try:
+             messages = [
+                 {
+                     "role": "system",
+                     "content": "You are an expert to analyze the image and provide useful information for users.",
+                 },
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "text", "text": prompt},
+                     ],
+                 },
+             ]
+
+             for data_url in data_url_list:
+                 messages[1]["content"].insert(
+                     0, {"type": "image_url", "image_url": {"url": data_url}}
+                 )
+
+             response = client.chat.completions.create(
+                 model="gpt-4o-2024-08-06",
+                 messages=messages,
+                 temperature=0.2,
+                 max_tokens=8192,
+             )
+             return response.choices[0].message.content
+
+         except Exception as e:
+             if attempt == max_retries - 1:
+                 raise Exception(
+                     f"Failed after {max_retries} attempts. Last error: {str(e)}"
+                 )
+             delay = initial_delay * (2**attempt) + random.uniform(
+                 0, 0.1 * initial_delay * (2**attempt)
+             )
+             time.sleep(delay)
+
+
+ def process_single_item(example):
+     try:
+         image_path = example["image_path"]
+         formatted_prompt = PROMPT_FORMAT.format(
+             original_question=example["question"], original_answer=example["answer"]
+         )
+
+         response = gpt4o_query(image_path, formatted_prompt)
+         example["gpt4o_response"] = response
+         return example
+     except Exception as e:
+         print(f"Error processing item: {str(e)}")
+         example["gpt4o_response"] = None
+         return example
+
+
+ def main():
+     dataset_path = "path/to/your/dataset"
+     full_dataset = load_from_disk(dataset_path)
+
+     processed_dataset = full_dataset.map(
+         function=partial(process_single_item),
+         num_proc=256,
+         desc="Processing dataset with GPT-4o",
+         keep_in_memory=True,
+     )
+
+     output_path = f"{dataset_path}_processed"
+     processed_dataset.save_to_disk(output_path)
+     print(f"Processed dataset saved to: {output_path}")
+
+
+ if __name__ == "__main__":
+     main()
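The retry loop in `gpt4o_query` above uses exponential backoff with jitter before each retry. A minimal standalone sketch of just that delay schedule (the helper name `backoff_delays` is mine, not from the repo):

```python
import random

def backoff_delays(max_retries=5, initial_delay=3):
    """Delays slept between retries: initial_delay * 2**attempt, plus up to 10% jitter."""
    return [
        initial_delay * (2 ** attempt)
        + random.uniform(0, 0.1 * initial_delay * (2 ** attempt))
        for attempt in range(max_retries - 1)  # no sleep after the final attempt
    ]

# Four strictly increasing delays, the first one between 3.0 and 3.3 seconds.
print(backoff_delays())
```

The jitter term keeps many parallel workers (the script maps with `num_proc=256`) from retrying the rate-limited endpoint in lockstep.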
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/lmms_eval_qwen2vl.sh ADDED
@@ -0,0 +1,61 @@
+ export HF_HOME="<CACHE_DIR>"
+ export HF_TOKEN="<HF_TOKEN>"
+ export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+ export API_TYPE="<API_TYPE>"
+ export AZURE_ENDPOINT="<AZURE_ENDPOINT>"
+ export AZURE_API_KEY="<API_KEY>"
+ export API_VERSION="<API_VERSION>"
+ export MODEL_VERSION="<MODEL_VERSION>"
+ export NAVIT_ATTENTION_IMPLEMENTATION="eager"
+
+ # Prompt for installation with 3-second timeout
+ read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
+ if [ "$install_deps" = "YES" ]; then
+     # Prepare the environment
+     pip3 install --upgrade pip
+     pip3 install -U setuptools
+
+     cd <PROJECT_ROOT>
+     if [ ! -d "maas_engine" ]; then
+         git clone <REPO_URL>
+     else
+         echo "maas_engine directory already exists, skipping clone"
+     fi
+     cd maas_engine
+     git pull
+     git checkout <BRANCH_NAME>
+     pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
+
+     current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
+     if [ "$current_version" != "4.46.2" ]; then
+         echo "Installing transformers 4.46.2 (current version: $current_version)"
+         pip3 install transformers==4.46.2
+     else
+         echo "transformers 4.46.2 is already installed"
+     fi
+
+     cd <LMMS_EVAL_DIR>
+     rm -rf <TARGET_DIR>
+     pip3 install -e .
+     pip3 install -U pydantic
+     pip3 install Levenshtein
+     pip3 install nltk
+     python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
+ fi
+
+ TASKS=mmmu_val,mathvista_testmini,mmmu_pro
+ MODEL_BASENAME=qwen2_vl
+
+ model_checkpoint="<MODEL_CHECKPOINT_PATH>"
+ echo "MODEL_BASENAME: ${MODEL_BASENAME}"
+ cd <LMMS_EVAL_DIR>
+
+ python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
+     --model qwen2_vl \
+     --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
+     --tasks ${TASKS} \
+     --batch_size 1 \
+     --log_samples \
+     --log_samples_suffix ${MODEL_BASENAME} \
+     --output_path ./logs
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/prepare_hf_data.py ADDED
@@ -0,0 +1,166 @@
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import pandas as pd
+ import random
+ from typing import List, Dict
+ import numpy as np
+ from concurrent.futures import ThreadPoolExecutor
+ from tqdm import tqdm
+ import datasets
+
+ import io
+ from datasets import load_dataset, load_from_disk, concatenate_datasets
+ from PIL import Image
+ from tqdm import tqdm
+ from functools import partial
+ from pillow_avif import AvifImagePlugin
+ from datasets import Dataset
+ import json
+ import yaml
+ import os
+ import re
+ import time
+ import random
+ import base64
+ from openai import AzureOpenAI
+ import concurrent.futures
+ from typing import List, Dict
+ import argparse
+ import time
+
+
+ def extract_problem_solution(gpt4o_response):
+     # Split the response into parts
+     parts = gpt4o_response.split("<think>")
+
+     # Extract the problem (first part before any <think> tags)
+     problem = parts[0].strip()
+     # Remove "Question:" prefix if it exists
+     problem = re.sub(r"^Question:\s*", "", problem)
+     # Remove "Answer:" at the end of the problem
+     problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
+
+     # Combine all the reasoning steps into a single <think> block
+     think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
+     solution = f"<think>{' '.join(think_parts)}</think>"
+
+     # Add the final answer if it exists, removing "Answer:" prefix
+     if "<answer>" in gpt4o_response:
+         final_answer = (
+             gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
+         )
+         final_answer = re.sub(r"^Answer:\s*", "", final_answer)
+         solution += f"\n\n<answer>{final_answer}</answer>"
+
+     return problem, solution
+
+
+ def load_image_from_path(image_path):
+     try:
+         img = Image.open(image_path)
+         return img
+     except Exception as e:
+         print(f"Error loading image {image_path}: {str(e)}")
+         return None
+
+
+ def process_raw_data(raw_data):
+     # Parse the raw data if it's a string
+     if isinstance(raw_data, str):
+         data = json.loads(raw_data)
+     else:
+         data = raw_data
+
+     # Extract problem and solution
+     try:
+         problem, solution = extract_problem_solution(data["gpt4o_response"])
+         image = load_image_from_path(data["image_path"])
+
+         return {
+             "image": image,
+             "problem": problem,
+             "solution": solution,
+             "original_question": data["question"],
+             "original_answer": data["answer"],
+         }
+     except Exception as e:
+         print(f"Error processing data {data}: {str(e)}")
+         return {
+             "image": None,
+             "problem": None,
+             "solution": None,
+             "original_question": None,
+             "original_answer": None,
+         }
+
+
+ raw_data_list = [
+     "/path/to/reasoning_data_with_response_90k_verified",
+ ]
+
+ raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
+
+ processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
+
+ hf_dict = {
+     "image": [],
+     "problem": [],
+     "solution": [],
+     "original_question": [],
+     "original_answer": [],
+ }
+
+ for item in tqdm(processed_data):
+     hf_dict["image"].append(item["image"])
+     hf_dict["problem"].append(item["problem"])
+     hf_dict["solution"].append(item["solution"])
+     hf_dict["original_question"].append(item["original_question"])
+     hf_dict["original_answer"].append(item["original_answer"])
+
+
+ features = datasets.Features(
+     {
+         "image": datasets.Image(),
+         "problem": datasets.Value("string"),
+         "solution": datasets.Value("string"),
+         "original_question": datasets.Value("string"),
+         "original_answer": datasets.Value("string"),
+     }
+ )
+
+
+ def has_empty_tags(text):
+     # Pattern to match empty tags like <tag></tag>
+     pattern = r"<[^>]+></[^>]+>"
+     return bool(re.search(pattern, text))
+
+
+ def has_answer_pattern(text):
+     if "Answer:" in text:
+         return True
+     return False
+
+
+ def has_valid_image_size(example):  # for Qwen2-VL-2B's processor requirement
+     # Assuming the image is in a format that can be checked for dimensions
+     # You might need to adjust this depending on how the image is stored in your dataset
+     try:
+         image = example["image"]  # or however your image is accessed
+         if isinstance(image, dict) and "height" in image and "width" in image:
+             return image["height"] >= 28 and image["width"] >= 28
+         # If image is a PIL Image or similar
+         return image.height >= 28 and image.width >= 28
+     except:
+         return False
+
+
+ ds = datasets.Dataset.from_dict(hf_dict, features=features)
+ ds = ds.filter(
+     lambda x: not has_empty_tags(x["solution"])
+     and not has_answer_pattern(x["problem"])
+     and has_valid_image_size(x)
+     and x["image"] is not None,
+     num_proc=128,
+ )
+ # Push to Hugging Face Hub
+ ds.push_to_hub("path/to/your/dataset")
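A quick standalone sanity check of the tag-extraction logic in `extract_problem_solution` above (the function body is copied verbatim from the script; the sample response string is made up for illustration):

```python
import re

def extract_problem_solution(gpt4o_response):
    # Split on <think>; everything before the first tag is the problem text
    parts = gpt4o_response.split("<think>")
    problem = parts[0].strip()
    problem = re.sub(r"^Question:\s*", "", problem)
    problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
    # Merge all reasoning segments into one <think> block
    think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
    solution = f"<think>{' '.join(think_parts)}</think>"
    # Append the final <answer> block if present
    if "<answer>" in gpt4o_response:
        final_answer = gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
        final_answer = re.sub(r"^Answer:\s*", "", final_answer)
        solution += f"\n\n<answer>{final_answer}</answer>"
    return problem, solution

resp = "Question: What is 2+2? Answer: <think>2 plus 2 equals 4.</think> <answer>4</answer>"
problem, solution = extract_problem_solution(resp)
print(problem)  # What is 2+2?
```

Note how both the `Question:` prefix and the trailing `Answer:` marker are stripped from the problem, matching the prompt format used in `create_vision_cot_data.py`.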
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/train_aria_moe.sh ADDED
@@ -0,0 +1,68 @@
+ #!/bin/bash
+
+ export NCCL_BLOCKING_WAIT=0
+ export TOKENIZERS_PARALLELISM=false
+ export OMP_NUM_THREADS=8
+ export NCCL_IB_DISABLE=0
+ export NCCL_IB_GID_INDEX=3
+ export NCCL_SOCKET_IFNAME=eth0
+ export NCCL_DEBUG=INFO
+
+ # CONFIG Huggingface
+ # export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_1>"
+ export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_2>"
+ export HF_HOME="$HOME/.cache/huggingface"
+ export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+ export NCCL_DEBUG=INFO
+
+ GPUS="0,1,2,3,4,5,6,7"
+
+ # take the first port of worker0
+ ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+ port=${ports[0]}
+ port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+ echo "total workers: ${ARNOLD_WORKER_NUM}"
+ echo "cur worker id: ${ARNOLD_ID}"
+ echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+ echo "master ip: ${METIS_WORKER_0_HOST}"
+ echo "master port: ${port}"
+ echo "master port in cmd: ${port_in_cmd}"
+
+ # export WANDB_BASE_URL=https://api.wandb.ai
+ # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
+ # wandb login $WANDB_API_KEY
+
+ export WANDB_BASE_URL=https://api.wandb.ai
+ export WANDB_PROJECT=vision-reasoning
+ export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
+ export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+ wandb login $WANDB_API_KEY
+
+ cd /home/tiger/multimodal-open-r1
+ # pip3 install vllm==0.6.6.post1
+ pip3 install -e ".[dev]"
+ pip3 install wandb==0.18.3
+
+ torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+     --nnodes="${ARNOLD_WORKER_NUM}" \
+     --node_rank="${ARNOLD_ID}" \
+     --master_addr="${METIS_WORKER_0_HOST}" \
+     --master_port="${port_in_cmd}" \
+     src/open_r1/grpo.py \
+     --deepspeed scripts/zero3.json \
+     --output_dir Aria-GRPO-mini_cot_80k \
+     --model_name_or_path rhymes-ai/Aria \
+     --dataset_name luodian/mini_cot_80k \
+     --max_prompt_length 8192 \
+     --per_device_train_batch_size 1 \
+     --gradient_accumulation_steps 1 \
+     --logging_steps 1 \
+     --bf16 \
+     --report_to wandb \
+     --gradient_checkpointing true \
+     --attn_implementation eager \
+     --save_total_limit 8 \
+     --num_train_epochs 1 \
+     --run_name $WANDB_RUN_NAME
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/train_qwen2_vl.sh ADDED
@@ -0,0 +1,61 @@
+ #!/bin/bash
+
+ export NCCL_BLOCKING_WAIT=0
+ export TOKENIZERS_PARALLELISM=false
+ export OMP_NUM_THREADS=8
+ export NCCL_IB_DISABLE=0
+ export NCCL_IB_GID_INDEX=3
+ export NCCL_SOCKET_IFNAME=eth0
+ export NCCL_DEBUG=INFO
+
+ GPUS="0,1,2,3,4,5,6,7"
+
+ # take the first port of worker0
+ ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+ port=${ports[0]}
+ port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+ echo "total workers: ${ARNOLD_WORKER_NUM}"
+ echo "cur worker id: ${ARNOLD_ID}"
+ echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+ echo "master ip: ${METIS_WORKER_0_HOST}"
+ echo "master port: ${port}"
+ echo "master port in cmd: ${port_in_cmd}"
+
+ # export WANDB_BASE_URL=https://api.wandb.ai
+ # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
+ # wandb login $WANDB_API_KEY
+
+ export WANDB_BASE_URL=https://api.wandb.ai
+ export WANDB_PROJECT=vision-reasoning
+ export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
+ export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+ wandb login $WANDB_API_KEY
+
+ cd /home/tiger/multimodal-open-r1
+ # pip3 install vllm==0.6.6.post1
+ pip3 install -e ".[dev]"
+ pip3 install wandb==0.18.3
+
+ torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+     --nnodes="${ARNOLD_WORKER_NUM}" \
+     --node_rank="${ARNOLD_ID}" \
+     --master_addr="${METIS_WORKER_0_HOST}" \
+     --master_port="${port_in_cmd}" \
+     src/open_r1/grpo.py \
+     --deepspeed scripts/zero3.json \
+     --output_dir checkpoints/${WANDB_RUN_NAME} \
+     --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
+     --dataset_name luodian/${DATASET_NAME} \
+     --max_prompt_length 8192 \
+     --per_device_train_batch_size 1 \
+     --gradient_accumulation_steps 1 \
+     --logging_steps 1 \
+     --bf16 \
+     --report_to wandb \
+     --gradient_checkpointing true \
+     --attn_implementation flash_attention_2 \
+     --max_pixels 2359296 \
+     --save_total_limit 8 \
+     --num_train_epochs 1 \
+     --run_name $WANDB_RUN_NAME
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero2.json ADDED
@@ -0,0 +1,41 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": "auto",
+             "eps": "auto",
+             "weight_decay": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "offload_optimizer": {
+             "device": "none",
+             "pin_memory": true
+         },
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": false,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 2e8,
+         "contiguous_gradients": true
+     },
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 100,
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "wall_clock_breakdown": false
+ }
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero3.json ADDED
@@ -0,0 +1,41 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+
+     "zero_optimization": {
+         "stage": 3,
+         "offload_optimizer": {
+             "device": "none",
+             "pin_memory": true
+         },
+         "offload_param": {
+             "device": "none",
+             "pin_memory": true
+         },
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": "auto",
+         "stage3_prefetch_bucket_size": "auto",
+         "stage3_param_persistence_threshold": "auto",
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": true
+     },
+
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 100,
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "wall_clock_breakdown": false
+ }
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero3.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: true
+   zero3_save_16bit_model: true
+   zero_stage: 3
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
previous_version/Video-R1-main-previous/src/r1-v/local_scripts/zero3_offload.json ADDED
@@ -0,0 +1,48 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": "auto",
+             "eps": "auto",
+             "weight_decay": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 3,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "offload_param": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": "auto",
+         "stage3_prefetch_bucket_size": "auto",
+         "stage3_param_persistence_threshold": "auto",
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "gather_16bit_weights_on_model_save": true
+     },
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "steps_per_print": 1e5,
+     "wall_clock_breakdown": false
+ }
previous_version/Video-R1-main-previous/src/r1-v/run_grpo.sh ADDED
@@ -0,0 +1,29 @@
+ cd src/r1-v
+
+ export DEBUG_MODE="true"
+ export LOG_PATH="./debug_log_2b.txt"
+
+
+
+ torchrun --nproc_per_node="8" \
+     --nnodes="1" \
+     --node_rank="0" \
+     --master_addr="127.0.0.1" \
+     --master_port="12345" \
+     src/open_r1/grpo.py \
+     --output_dir <OUTPUT_DIR> \
+     --model_name_or_path <PATH-TO-Qwen2-VL-2B-Instruct> \
+     --dataset_name <PATH-TO-DATASET> \
+     --max_prompt_length 1024 \
+     --per_device_train_batch_size 1 \
+     --gradient_accumulation_steps 2 \
+     --logging_steps 1 \
+     --bf16 \
+     --report_to wandb \
+     --gradient_checkpointing false \
+     --attn_implementation flash_attention_2 \
+     --max_pixels 401408 \
+     --num_train_epochs 2 \
+     --run_name Qwen2-VL-2B-GRPO-CLEVR-70k \
+     --save_steps 100 \
+     --save_only_model true
previous_version/Video-R1-main-previous/src/r1-v/setup.cfg ADDED
@@ -0,0 +1,41 @@
+ [isort]
+ default_section = FIRSTPARTY
+ ensure_newline_before_comments = True
+ force_grid_wrap = 0
+ include_trailing_comma = True
+ known_first_party = open_r1
+ known_third_party =
+     transformers
+     datasets
+     fugashi
+     git
+     h5py
+     matplotlib
+     nltk
+     numpy
+     packaging
+     pandas
+     psutil
+     pytest
+     rouge_score
+     sacrebleu
+     seqeval
+     sklearn
+     streamlit
+     torch
+     tqdm
+
+ line_length = 119
+ lines_after_imports = 2
+ multi_line_output = 3
+ use_parentheses = True
+
+ [flake8]
+ ignore = E203, E501, E741, W503, W605
+ max-line-length = 119
+ per-file-ignores =
+     # imported but unused
+     __init__.py: F401
+
+ [tool:pytest]
+ doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
previous_version/Video-R1-main-previous/src/r1-v/setup.py ADDED
@@ -0,0 +1,132 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
+
+
+ import re
+ import shutil
+ from pathlib import Path
+
+ from setuptools import find_packages, setup
+
+
+ # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
+ stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
+ if stale_egg_info.exists():
+     print(
+         (
+             "Warning: {} exists.\n\n"
+             "If you recently updated open_r1, this is expected,\n"
+             "but it may prevent open_r1 from installing in editable mode.\n\n"
+             "This directory is automatically generated by Python's packaging tools.\n"
+             "I will remove it now.\n\n"
+             "See https://github.com/pypa/pip/issues/5466 for details.\n"
+         ).format(stale_egg_info)
+     )
+     shutil.rmtree(stale_egg_info)
+
+
+ # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
+ # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
+ _deps = [
+     "accelerate>=1.2.1",
+     "bitsandbytes>=0.43.0",
+     "black>=24.4.2",
+     "datasets>=3.2.0",
+     "deepspeed==0.15.4",
+     "distilabel[vllm,ray,openai]>=1.5.2",
+     "einops>=0.8.0",
+     "flake8>=6.0.0",
+     "hf_transfer>=0.1.4",
+     "huggingface-hub[cli]>=0.19.2,<1.0",
+     "isort>=5.12.0",
+     "liger_kernel==0.5.2",
+     "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
+     "math-verify",  # Used for math verification in grpo
+     "packaging>=23.0",
+     "parameterized>=0.9.0",
+     "pytest",
+     "safetensors>=0.3.3",
+     "sentencepiece>=0.1.99",
+     "torch>=2.5.1",
+     "transformers @ git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef",
+     "trl==0.14.0",
+     "vllm==0.6.6.post1",
+     "wandb>=0.19.1",
+     "pillow",
+ ]
+
+ # this is a lookup table with items like:
+ #
+ # tokenizers: "tokenizers==0.9.4"
+ # packaging: "packaging"
+ #
+ # some of the values are versioned whereas others aren't.
+ deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
+
+
+ def deps_list(*pkgs):
+     return [deps[pkg] for pkg in pkgs]
+
+
+ extras = {}
+ extras["tests"] = deps_list("pytest", "parameterized")
+ extras["torch"] = deps_list("torch")
+ extras["quality"] = deps_list("black", "isort", "flake8")
+ extras["eval"] = deps_list("lighteval", "math-verify")
+ extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
+
+ # core dependencies shared across the whole project - keep this to a bare minimum :)
+ install_requires = [
+     deps["accelerate"],
+     deps["bitsandbytes"],
+     deps["einops"],
+     deps["datasets"],
+     deps["deepspeed"],
+     deps["hf_transfer"],
+     deps["huggingface-hub"],
+     deps["liger_kernel"],
+     deps["packaging"],  # utilities from PyPA to e.g., compare versions
+     deps["safetensors"],
+     deps["sentencepiece"],
+     deps["transformers"],
+     deps["trl"],
+ ]
+
+ setup(
+     name="r1-v",
+     version="0.1.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+     author="The r1-v team and the Hugging Face team (past and future)",
+     description="R1-V",
+     license="Apache",
+     url="https://github.com/Deep-Agent/R1-V",
+     package_dir={"": "src"},
+     packages=find_packages("src"),
+     zip_safe=False,
+     extras_require=extras,
+     python_requires=">=3.10.9",
+     install_requires=install_requires,
+     classifiers=[
+         "Development Status :: 3 - Alpha",
+         "Intended Audience :: Developers",
+         "Intended Audience :: Education",
+         "Intended Audience :: Science/Research",
+         "License :: OSI Approved :: Apache Software License",
+         "Operating System :: OS Independent",
+         "Programming Language :: Python :: 3",
+         "Programming Language :: Python :: 3.10",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     ],
+ )
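The `deps` lookup table in `setup.py` is built by a fairly dense regex. A small standalone sketch of what it produces, using three sample specs taken from `_deps` above:

```python
import re

_deps = [
    "accelerate>=1.2.1",
    "distilabel[vllm,ray,openai]>=1.5.2",
    "math-verify",
]

# Maps a bare package name to its full requirement spec: the outer group
# captures the whole spec, the inner group captures just the name, with
# extras ([...]) and version constraints excluded from the key.
deps = {
    b: a
    for a, b in (
        re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0]
        for x in _deps
    )
}

print(deps["distilabel"])  # distilabel[vllm,ray,openai]>=1.5.2
```

This is why `deps_list("math-verify")` and `deps["accelerate"]` in the surrounding code can look dependencies up by name alone, regardless of how they are pinned.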
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/__init__.py ADDED
File without changes
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/evaluate.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Custom evaluation tasks for LightEval."""
+
+ from lighteval.metrics.dynamic_metrics import (
+     ExprExtractionConfig,
+     LatexExtractionConfig,
+     multilingual_extractive_match_metric,
+ )
+ from lighteval.tasks.lighteval_task import LightevalTaskConfig
+ from lighteval.tasks.requests import Doc
+ from lighteval.utils.language import Language
+
+
+ metric = multilingual_extractive_match_metric(
+     language=Language.ENGLISH,
+     fallback_mode="first_match",
+     precision=5,
+     gold_extraction_target=(LatexExtractionConfig(),),
+     pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
+     aggregation_function=max,
+ )
+
+
+ def prompt_fn(line, task_name: str = None):
+     """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
+     return Doc(
+         task_name=task_name,
+         query=line["problem"],
+         choices=[line["solution"]],
+         gold_index=0,
+     )
+
+
+ # Define tasks
+ aime24 = LightevalTaskConfig(
+     name="aime24",
+     suite=["custom"],
+     prompt_function=prompt_fn,
+     hf_repo="HuggingFaceH4/aime_2024",
+     hf_subset="default",
+     hf_avail_splits=["train"],
+     evaluation_splits=["train"],
+     few_shots_split=None,
+     few_shots_select=None,
+     generation_size=32768,
+     metric=[metric],
+     version=1,
+ )
+ math_500 = LightevalTaskConfig(
+     name="math_500",
+     suite=["custom"],
+     prompt_function=prompt_fn,
+     hf_repo="HuggingFaceH4/MATH-500",
+     hf_subset="default",
+     hf_avail_splits=["test"],
+     evaluation_splits=["test"],
+     few_shots_split=None,
+     few_shots_select=None,
+     generation_size=32768,
+     metric=[metric],
+     version=1,
+ )
+
+ # Add tasks to the table
+ TASKS_TABLE = []
+ TASKS_TABLE.append(aime24)
+ TASKS_TABLE.append(math_500)
+
+ # MODULE LOGIC
+ if __name__ == "__main__":
+     print([t.name for t in TASKS_TABLE])  # LightevalTaskConfig is a dataclass, so use attribute access
+     print(len(TASKS_TABLE))
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/generate.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from distilabel.llms import OpenAILLM
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import TextGeneration
+
+
+ def build_distilabel_pipeline(
+     model: str,
+     base_url: str = "http://localhost:8000/v1",
+     prompt_column: Optional[str] = None,
+     temperature: Optional[float] = None,
+     top_p: Optional[float] = None,
+     max_new_tokens: int = 8192,
+     num_generations: int = 1,
+ ) -> Pipeline:
+     generation_kwargs = {"max_new_tokens": max_new_tokens}
+
+     if temperature is not None:
+         generation_kwargs["temperature"] = temperature
+
+     if top_p is not None:
+         generation_kwargs["top_p"] = top_p
+
+     with Pipeline().ray() as pipeline:
+         TextGeneration(
+             llm=OpenAILLM(
+                 base_url=base_url,
+                 api_key="something",
+                 model=model,
+                 # thinking can take some time...
+                 timeout=10 * 60,
+                 generation_kwargs=generation_kwargs,
+             ),
+             input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
+             input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
+             num_generations=num_generations,
+         )
+
+     return pipeline
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     from datasets import load_dataset
+
+     parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
+     parser.add_argument(
+         "--hf-dataset",
+         type=str,
+         required=True,
+         help="HuggingFace dataset to load",
+     )
+     parser.add_argument(
+         "--hf-dataset-config",
+         type=str,
+         required=False,
+         help="Dataset config to use",
+     )
+     parser.add_argument(
+         "--hf-dataset-split",
+         type=str,
+         default="train",
+         help="Dataset split to use",
+     )
+     parser.add_argument("--prompt-column", type=str, default="prompt")
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help="Model name to use for generation",
+     )
+     parser.add_argument(
+         "--vllm-server-url",
+         type=str,
+         default="http://localhost:8000/v1",
+         help="URL of the vLLM server",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         help="Temperature for generation",
+     )
+     parser.add_argument(
+         "--top-p",
+         type=float,
+         help="Top-p value for generation",
+     )
+     parser.add_argument(
+         "--max-new-tokens",
+         type=int,
+         default=8192,
+         help="Maximum number of new tokens to generate",
+     )
+     parser.add_argument(
+         "--num-generations",
+         type=int,
+         default=1,
+         help="Number of generations per problem",
+     )
+     parser.add_argument(
+         "--hf-output-dataset",
+         type=str,
+         required=False,
+         help="HuggingFace repo to push results to",
+     )
+     parser.add_argument(
+         "--private",
+         action="store_true",
+         help="Whether to make the output dataset private when pushing to HF Hub",
+     )
+
+     args = parser.parse_args()
+
+     print("\nRunning with arguments:")
+     for arg, value in vars(args).items():
+         print(f"  {arg}: {value}")
+     print()
+
+     print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
+     # Pass the config through so it is actually used, not just printed
+     dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
+     print("Dataset loaded!")
+
+     pipeline = build_distilabel_pipeline(
+         model=args.model,
+         base_url=args.vllm_server_url,
+         prompt_column=args.prompt_column,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         max_new_tokens=args.max_new_tokens,
+         num_generations=args.num_generations,
+     )
+
+     print("Running generation pipeline...")
+     distiset = pipeline.run(dataset=dataset, use_cache=False)
+     print("Generation pipeline finished!")
+
+     if args.hf_output_dataset:
+         print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
+         distiset.push_to_hub(args.hf_output_dataset, private=args.private)
+         print("Dataset pushed!")
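The kwargs assembly at the top of `build_distilabel_pipeline` follows a small pattern worth noting: only sampling parameters that were explicitly set are forwarded, so the server's defaults apply otherwise. A standalone sketch of the same pattern (the function name here is ours, for illustration only):

```python
def build_generation_kwargs(max_new_tokens=8192, temperature=None, top_p=None):
    """Collect only the sampling parameters that were explicitly set."""
    kwargs = {"max_new_tokens": max_new_tokens}
    if temperature is not None:
        kwargs["temperature"] = temperature
    if top_p is not None:
        kwargs["top_p"] = top_p
    return kwargs
```

Omitted parameters never appear in the dict, so downstream defaults are not accidentally overridden with `None`.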
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/grpo.py ADDED
@@ -0,0 +1,229 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import re
+ from datetime import datetime
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
+ from transformers import Qwen2VLForConditionalGeneration
+
+ from math_verify import parse, verify
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainer
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
+
+
+ @dataclass
+ class GRPOScriptArguments(ScriptArguments):
+     """
+     Script arguments for the GRPO training script.
+
+     Args:
+         reward_funcs (`list[str]`):
+             List of reward functions. Possible values: 'accuracy', 'format'.
+     """
+
+     reward_funcs: list[str] = field(
+         default_factory=lambda: ["accuracy", "format"],
+         metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
+     )
+     max_pixels: Optional[int] = field(
+         default=12845056,
+         metadata={"help": "Maximum number of pixels for the image"},
+     )
+     min_pixels: Optional[int] = field(
+         default=3136,
+         metadata={"help": "Minimum number of pixels for the image"},
+     )
+
+
+ def accuracy_reward(completions, solution, **kwargs):
+     """Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
+     contents = [completion[0]["content"] for completion in completions]
+     rewards = []
+     current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
+     for content, sol in zip(contents, solution):
+         reward = 0.0
+         # Try symbolic verification first
+         try:
+             answer = parse(content)
+             if float(verify(answer, parse(sol))) > 0:
+                 reward = 1.0
+         except Exception:
+             pass  # Continue to next verification method if this fails
+
+         # If symbolic verification failed, try string matching
+         if reward == 0.0:
+             try:
+                 # Extract answer from solution if it has think/answer tags
+                 sol_match = re.search(r'<answer>(.*?)</answer>', sol)
+                 ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
+
+                 # Extract answer from content if it has think/answer tags
+                 content_match = re.search(r'<answer>(.*?)</answer>', content)
+                 student_answer = content_match.group(1).strip() if content_match else content.strip()
+
+                 # Compare the extracted answers
+                 if student_answer == ground_truth:
+                     reward = 1.0
+             except Exception:
+                 pass  # Keep reward as 0.0 if both methods fail
+
+         rewards.append(reward)
+         if os.getenv("DEBUG_MODE") == "true":
+             log_path = os.getenv("LOG_PATH")
+             # local_rank = int(os.getenv("LOCAL_RANK", 0))
+             with open(log_path, "a") as f:
+                 f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
+                 f.write(f"Content: {content}\n")
+                 f.write(f"Solution: {sol}\n")
+     return rewards
+
+
+ def format_reward(completions, **kwargs):
+     """Reward function that checks if the completion has a specific format."""
+     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+     completion_contents = [completion[0]["content"] for completion in completions]
+     matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
+     return [1.0 if match else 0.0 for match in matches]
+
+
+ reward_funcs_registry = {
+     "accuracy": accuracy_reward,
+     "format": format_reward,
+ }
+
+ SYSTEM_PROMPT = (
+     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+     "<think> reasoning process here </think><answer> answer here </answer>"
+ )
+
+
+ def main(script_args, training_args, model_args):
+     # Get reward functions
+     reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
+
+     if script_args.dataset_name.endswith('.jsonl'):
+         dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
+     else:
+         # Load the dataset
+         dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
+
+     # Format into conversation
+     def make_conversation(example):
+         return {
+             "prompt": [
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": example["problem"]},
+             ],
+         }
+
+     # def make_conversation_image(example):
+     #     return {
+     #         "prompt": [
+     #             {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
+     #             {
+     #                 "role": "user",
+     #                 "content": [
+     #                     {"type": "image"},
+     #                     {"type": "text", "text": example["problem"]},
+     #                 ],
+     #             },
+     #         ],
+     #     }
+
+     QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
+
+     def make_conversation_image(example):
+         return {
+             "prompt": [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "image"},
+                         {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
+                     ],
+                 },
+             ],
+         }
+
+     def make_conversation_video(example):
+         return {
+             "prompt": [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "video"},
+                         {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
+                     ],
+                 },
+             ],
+         }
+
+     if "image" in dataset[script_args.dataset_train_split].features:
+         print("has image in dataset")
+         dataset = dataset.map(make_conversation_image)  # Utilize multiprocessing for faster mapping
+         # dataset = dataset.remove_columns(["original_question", "original_answer"])
+     elif "video_filename" in dataset[script_args.dataset_train_split].features:
+         print("has video in dataset")
+         dataset = dataset.map(make_conversation_video)
+     else:
+         print("no image in dataset")
+         dataset = dataset.map(make_conversation)
+         dataset = dataset.remove_columns("messages")
+
+     trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainer
+     print("using: ", trainer_cls)
+
+     # Initialize the GRPO trainer
+     trainer = trainer_cls(
+         model=model_args.model_name_or_path,
+         reward_funcs=reward_funcs,
+         args=training_args,
+         train_dataset=dataset[script_args.dataset_train_split],
+         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
+         peft_config=get_peft_config(model_args),
+         attn_implementation=model_args.attn_implementation,
+         max_pixels=script_args.max_pixels,
+         min_pixels=script_args.min_pixels,
+     )
+
+     # Train and push the model to the Hub
+     trainer.train()
+
+     # Save and push to hub
+     trainer.save_model(training_args.output_dir)
+     if training_args.push_to_hub:
+         trainer.push_to_hub(dataset_name=script_args.dataset_name)
+
+
+ if __name__ == "__main__":
+     parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
+     script_args, training_args, model_args = parser.parse_args_and_config()
+     main(script_args, training_args, model_args)
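The two reward functions in `grpo.py` are pure functions of strings, so their core behaviour can be checked without any model in the loop. A minimal sketch of the same checks, re-implemented here per-completion for illustration (these helper names are ours, not part of the repo):

```python
import re

def format_reward_single(content: str) -> float:
    """1.0 if the completion is exactly <think>...</think><answer>...</answer>, else 0.0."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return 1.0 if re.fullmatch(pattern, content, re.DOTALL) else 0.0

def extract_answer(text: str) -> str:
    """Fallback string matching used by accuracy_reward: pull out the <answer> span, or use the whole text."""
    match = re.search(r"<answer>(.*?)</answer>", text)
    return match.group(1).strip() if match else text.strip()
```

Note that `re.fullmatch` with `re.DOTALL` requires the *entire* completion to follow the tag structure (newlines allowed inside the tags), which is what makes the format reward strict.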
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/sft.py ADDED
@@ -0,0 +1,322 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Supervised fine-tuning script for decoder language models.
+
+ Usage:
+
+ # On 1 node of 8 x H100s
+ accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
+     --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
+     --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
+     --learning_rate 2.0e-5 \
+     --num_train_epochs 1 \
+     --packing \
+     --max_seq_length 4096 \
+     --per_device_train_batch_size 4 \
+     --gradient_accumulation_steps 4 \
+     --gradient_checkpointing \
+     --bf16 \
+     --logging_steps 5 \
+     --eval_strategy steps \
+     --eval_steps 100 \
+     --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
+ """
+
+ import logging
+ import os
+ import sys
+
+ import datasets
+ from dataclasses import dataclass, field
+ from typing import Optional
+ import torch
+ import transformers
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, set_seed, AutoProcessor
+ from transformers.trainer_utils import get_last_checkpoint
+ import trl
+ from trl import (
+     ModelConfig,
+     ScriptArguments,
+     SFTTrainer,
+     TrlParser,
+     get_kbit_device_map,
+     get_peft_config,
+     get_quantization_config,
+ )
+
+ from qwen_vl_utils import process_vision_info
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class SFTConfig(trl.SFTConfig):
+     """
+     Args for callbacks, benchmarks, etc.
+     """
+
+     benchmarks: list[str] = field(
+         default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
+     )
+     callbacks: list[str] = field(
+         default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
+     )
+     system_prompt: Optional[str] = field(
+         default=None,
+         metadata={"help": "The optional system prompt to use for benchmarking."},
+     )
+     hub_model_revision: Optional[str] = field(
+         default="main",
+         metadata={"help": "The Hub model branch to push the model to."},
+     )
+     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
+     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+
+
+ processor = None
+
+
+ def convert_example(example):
+     """
+     Convert an example into the "messages" format, e.g.:
+     {
+         "system": "You are a helpful assistant.",
+         "conversations": [
+             {"from": "user", "value": "How many objects are included in this image?",
+              "image_path": "/path/to/image.png"},
+             {"from": "assistant", "value": "<think>\nI can see 10 objects\n</think>\n<answer>\n10\n</answer>"}
+         ]
+     }
+     """
+     messages = []
+     if "system" in example:
+         messages.append({
+             "role": "system",
+             "content": [{"type": "text", "text": example["system"]}],
+         })
+     else:
+         SYSTEM_PROMPT = (
+             "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+             "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+             "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+             "<think> reasoning process here </think><answer> answer here </answer>"
+         )
+         messages.append({
+             "role": "system",
+             "content": [{"type": "text", "text": SYSTEM_PROMPT}],
+         })
+
+     thinking = example.get("thinking")
+     problem = example.get("problem")
+     solution = example.get("solution")
+     image = example.get("image")
+     messages.append({
+         "role": "user",
+         "content": [
+             {"type": "text", "text": problem},
+             {"type": "image", "image": image},
+         ],
+     })
+     messages.append({
+         "role": "assistant",
+         "content": f"{thinking}\n\n{solution}",
+     })
+
+     example["messages"] = messages
+     return example
+
+
+ def collate_fn(examples):
+     texts = [
+         processor.apply_chat_template(convert_example(example)["messages"], tokenize=False, add_generation_prompt=True)
+         for example in examples
+     ]
+     image_inputs = []
+     for example in examples:
+         imgs, vids = process_vision_info(example["messages"])
+         image_inputs.append(imgs)
+     batch = processor(
+         text=texts,
+         images=image_inputs,
+         return_tensors="pt",
+         padding=True,
+     )
+     labels = batch["input_ids"].clone()
+     labels[labels == processor.tokenizer.pad_token_id] = -100
+     image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
+     labels[labels == image_token_id] = -100
+     batch["labels"] = labels
+
+     return batch
+
+
+ def main(script_args, training_args, model_args):
+     # Set seed for reproducibility
+     set_seed(training_args.seed)
+
+     ###############
+     # Setup logging
+     ###############
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+     log_level = training_args.get_process_log_level()
+     logger.setLevel(log_level)
+     datasets.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.enable_default_handler()
+     transformers.utils.logging.enable_explicit_format()
+
+     # Log on each process a small summary
+     logger.warning(
+         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+         + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+     )
+     logger.info(f"Model parameters {model_args}")
+     logger.info(f"Script parameters {script_args}")
+     logger.info(f"Data parameters {training_args}")
+
+     # Check for last checkpoint
+     last_checkpoint = None
+     if os.path.isdir(training_args.output_dir):
+         last_checkpoint = get_last_checkpoint(training_args.output_dir)
+     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
+
+     ################
+     # Load datasets
+     ################
+     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
+
+     ################
+     # Load tokenizer
+     ################
+     global processor
+     if "vl" in model_args.model_name_or_path.lower():
+         processor = AutoProcessor.from_pretrained(
+             model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+         )
+         logger.info("Using AutoProcessor for vision-language model.")
+     else:
+         processor = AutoTokenizer.from_pretrained(
+             model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
+         )
+         logger.info("Using AutoTokenizer for text-only model.")
+
+     if hasattr(processor, "pad_token") and processor.pad_token is None:
+         processor.pad_token = processor.eos_token
+     elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
+         processor.tokenizer.pad_token = processor.tokenizer.eos_token
+
+     ###################
+     # Model init kwargs
+     ###################
+     logger.info("*** Initializing model kwargs ***")
+     torch_dtype = (
+         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
+     )
+     quantization_config = get_quantization_config(model_args)
+     model_kwargs = dict(
+         revision=model_args.model_revision,
+         trust_remote_code=model_args.trust_remote_code,
+         attn_implementation=model_args.attn_implementation,
+         torch_dtype=torch_dtype,
+         use_cache=False if training_args.gradient_checkpointing else True,
+         device_map=get_kbit_device_map() if quantization_config is not None else None,
+         quantization_config=quantization_config,
+     )
+     # training_args.model_init_kwargs = model_kwargs
+     from transformers import Qwen2VLForConditionalGeneration
+     model = Qwen2VLForConditionalGeneration.from_pretrained(
+         model_args.model_name_or_path, **model_kwargs
+     )
+
+     ############################
+     # Initialize the SFT Trainer
+     ############################
+     training_args.dataset_kwargs = {
+         "skip_prepare_dataset": True,
+     }
+     training_args.remove_unused_columns = False
+
+     trainer = SFTTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=dataset[script_args.dataset_train_split],
+         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
+         processing_class=processor.tokenizer,
+         data_collator=collate_fn,
+         peft_config=get_peft_config(model_args),
+     )
+
+     ###############
+     # Training loop
+     ###############
+     logger.info("*** Train ***")
+     checkpoint = None
+     if training_args.resume_from_checkpoint is not None:
+         checkpoint = training_args.resume_from_checkpoint
+     elif last_checkpoint is not None:
+         checkpoint = last_checkpoint
+     train_result = trainer.train(resume_from_checkpoint=checkpoint)
+     metrics = train_result.metrics
+     metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
+     trainer.log_metrics("train", metrics)
+     trainer.save_metrics("train", metrics)
+     trainer.save_state()
+
+     ##################################
+     # Save model and create model card
+     ##################################
+     logger.info("*** Save model ***")
+     trainer.save_model(training_args.output_dir)
+     processor.save_pretrained(training_args.output_dir)
+     logger.info(f"Model saved to {training_args.output_dir}")
+
+     # Save everything else on main process
+     kwargs = {
+         "dataset_name": script_args.dataset_name,
+         "tags": ["R1-V"],
+     }
+     if trainer.accelerator.is_main_process:
+         trainer.create_model_card(**kwargs)
+         # Restore k,v cache for fast inference
+         trainer.model.config.use_cache = True
+         trainer.model.config.save_pretrained(training_args.output_dir)
+
+     #############
+     # push to hub
+     #############
+     if training_args.push_to_hub:
+         logger.info("Pushing to hub...")
+         trainer.push_to_hub(**kwargs)
+         processor.push_to_hub(training_args.hub_model_id)
+
+
+ if __name__ == "__main__":
+     parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
+     script_args, training_args, model_args = parser.parse_args_and_config()
+     main(script_args, training_args, model_args)
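The label construction in `collate_fn` above follows the usual causal-LM recipe: labels start as a copy of `input_ids`, and positions that should not contribute to the loss (padding and image-placeholder tokens) are set to -100, the ignore index of PyTorch's cross-entropy loss. A list-based sketch of that masking step (token IDs below are made up for illustration):

```python
IGNORE_INDEX = -100  # positions with this label are skipped by cross-entropy loss

def mask_labels(input_ids, pad_token_id, image_token_id):
    """Copy input_ids into labels, masking pad and image-placeholder positions."""
    return [
        IGNORE_INDEX if tok in (pad_token_id, image_token_id) else tok
        for tok in input_ids
    ]
```

The real collator does the same thing with tensor indexing (`labels[labels == pad_token_id] = -100`), which is the vectorized equivalent of this loop.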
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/trainer/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .grpo_trainer import Qwen2VLGRPOTrainer
+ from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
+
+ __all__ = ["Qwen2VLGRPOTrainer", "Qwen2VLGRPOVLLMTrainer"]
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/trainer/grpo_trainer.py ADDED
@@ -0,0 +1,652 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+ import os
+ import textwrap
+ from collections import defaultdict
+ from typing import Any, Callable, Optional, Union
+
+ import torch
+ import torch.utils.data
+ import transformers
+ from datasets import Dataset, IterableDataset
+ from packaging import version
+ from transformers import (
+     AriaForConditionalGeneration,
+     AriaProcessor,
+     AutoModelForCausalLM,
+     AutoModelForSequenceClassification,
+     AutoProcessor,
+     AutoTokenizer,
+     GenerationConfig,
+     PreTrainedModel,
+     PreTrainedTokenizerBase,
+     Qwen2VLForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
+     Trainer,
+     TrainerCallback,
+     is_wandb_available,
+ )
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ from transformers.utils import is_peft_available
+
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
+ from trl.trainer.grpo_config import GRPOConfig
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url
+
+ from qwen_vl_utils import process_vision_info
+
+
+ if is_peft_available():
+     from peft import PeftConfig, get_peft_model
+
+ if is_wandb_available():
+     import wandb
+
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+
+ class Qwen2VLGRPOTrainer(Trainer):
+     """
+     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
+     paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
+
+     Example:
+
+     ```python
+     from datasets import load_dataset
+     from trl import GRPOTrainer
+
+     dataset = load_dataset("trl-lib/tldr", split="train")
+
+     trainer = GRPOTrainer(
+         model="Qwen/Qwen2-0.5B-Instruct",
+         reward_funcs="weqweasdas/RM-Gemma-2B",
+         train_dataset=dataset,
+     )
+
+     trainer.train()
+     ```
+
+     Args:
+         model (`Union[str, PreTrainedModel]`):
+             Model to be trained. Can be either:
+
+             - A string, being the *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or
+               a path to a *directory* containing model weights saved using
+               [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
+               loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
+               in `args.model_init_kwargs`.
+             - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
+         reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
+             Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
+             functions with the prompts and completions and sum the rewards. Can be either:
+
+             - A single reward function, such as:
+                 - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
+                   path to a *directory* containing model weights saved using
+                   [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+                   using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
+                   keyword arguments in `args.model_init_kwargs`.
+                 - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
+                 - A custom reward function: The function is provided with the prompts and the generated completions,
+                   plus any additional columns in the dataset. It should return a list of rewards. For more details, see
+                   [Using a custom reward function](#using-a-custom-reward-function).
+             - A list of reward functions, where each item can independently be any of the above types. Mixing different
+               types within the list (e.g., a string model ID and a custom reward function) is allowed.
+         args ([`GRPOConfig`], *optional*, defaults to `None`):
+             Configuration for this trainer. If `None`, a default configuration is used.
+         train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+             Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
+             ignored. The format of the samples can be either:
+
+             - [Standard](dataset_formats#standard): Each sample contains plain text.
+             - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+               and content).
+         eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+             Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+         processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
+             Processing class used to process the data. The padding side must be set to "left". If `None`, the
+             processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
+         reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
+             Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
+
+             - A single processing class: Used when `reward_funcs` contains only one reward function.
+             - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
+               If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
+               `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
+               For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
+               the corresponding entries in `reward_processing_classes` are ignored.
+         callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
+             List of callbacks to customize the training loop. Will add those to the list of default callbacks
+             detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+             If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+             method.
+         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+             A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+             model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+         peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
+             PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
+     """
+
+     def __init__(
+         self,
+         model: Union[str, PreTrainedModel],
+         reward_funcs: Union[RewardFunc, list[RewardFunc]],
+         args: GRPOConfig = None,
+         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+         eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
+         processing_class: Optional[PreTrainedTokenizerBase] = None,
+         reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+         peft_config: Optional["PeftConfig"] = None,
+         max_pixels: Optional[int] = 12845056,
+         min_pixels: Optional[int] = 3136,
+         attn_implementation: str = "flash_attention_2",
+     ):
+         # Args
+         if args is None:
+             model_name = model if isinstance(model, str) else model.config._name_or_path
+             model_name = model_name.split("/")[-1]
+             args = GRPOConfig(f"{model_name}-GRPO")
+
+         # Models
+         # Trained model
+         model_init_kwargs = args.model_init_kwargs or {}
+         model_init_kwargs["attn_implementation"] = attn_implementation
+         if isinstance(model, str):
+             model_id = model
+             torch_dtype = model_init_kwargs.get("torch_dtype")
+             if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+                 pass  # torch_dtype is already a torch.dtype or "auto" or None
+             elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
+                 torch_dtype = getattr(torch, torch_dtype)
+                 model_init_kwargs["torch_dtype"] = torch_dtype
+             else:
+                 raise ValueError(
+                     "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+                     f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+                 )
+             # Disable caching if gradient checkpointing is enabled (not supported)
+             model_init_kwargs["use_cache"] = (
+                 False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+             )
+             if "Qwen2-VL" in model_id:
+                 model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+             elif "Qwen2.5-VL" in model_id:
+                 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+             elif "Aria" in model_id:
+                 model_init_kwargs.pop("use_cache")
+                 model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+             else:
+                 model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+         else:
+             model_id = model.config._name_or_path
+             if args.model_init_kwargs is not None:
+                 raise ValueError(
+                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+                     "This argument can only be used when the `model` argument is a string."
+                 )
+
+         if peft_config is not None:
+             model = get_peft_model(model, peft_config)
+
+         # Reference model
+         if is_deepspeed_zero3_enabled():
+             if "Qwen2-VL" in model_id:
+                 self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+             elif "Qwen2.5-VL" in model_id:
+                 self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+             elif "Aria" in model_id:
+                 self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+             else:
+                 self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+         elif peft_config is None:
+             # If PEFT configuration is not provided, create a reference model based on the initial model.
+             self.ref_model = create_reference_model(model)
+         else:
+             # If PEFT is used, the reference model is not needed since the adapter can be disabled
+             # to revert to the initial model.
+             self.ref_model = None
+
+         # Processing class
+         if processing_class is None:
+             if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id:
+                 processing_class = AutoProcessor.from_pretrained(model_id)
+                 pad_token_id = processing_class.tokenizer.pad_token_id
+                 processing_class.pad_token_id = pad_token_id
+                 processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+                 if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
+                     processing_class.image_processor.max_pixels = max_pixels
+                     processing_class.image_processor.min_pixels = min_pixels
+             else:
+                 processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
+                 pad_token_id = processing_class.pad_token_id
+
+         # Reward functions
+         if not isinstance(reward_funcs, list):
+             reward_funcs = [reward_funcs]
+         for i, reward_func in enumerate(reward_funcs):
+             if isinstance(reward_func, str):
+                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                     reward_func, num_labels=1, **model_init_kwargs
+                 )
+         self.reward_funcs = reward_funcs
+
+         # Reward processing class
+         if reward_processing_classes is None:
+             reward_processing_classes = [None] * len(reward_funcs)
+         elif not isinstance(reward_processing_classes, list):
+             reward_processing_classes = [reward_processing_classes]
+         else:
+             if len(reward_processing_classes) != len(reward_funcs):
+                 raise ValueError("The number of reward processing classes must match the number of reward functions.")
+
+         for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
+             if isinstance(reward_func, PreTrainedModel):
+                 if reward_processing_class is None:
+                     reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                 if reward_processing_class.pad_token_id is None:
+                     reward_processing_class.pad_token = reward_processing_class.eos_token
+                 # The reward model computes the reward for the latest non-padded token in the input sequence.
+                 # So it's important to set the pad token ID to the padding token ID of the processing class.
+                 reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+                 reward_processing_classes[i] = reward_processing_class
+         self.reward_processing_classes = reward_processing_classes
+
+         # Data collator
+         def data_collator(features):  # No data collation is needed in GRPO
+             return features
+
+         # Training arguments
+         self.max_prompt_length = args.max_prompt_length
+         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
+         self.num_generations = args.num_generations  # = G in the GRPO paper
+         self.generation_config = GenerationConfig(
+             max_new_tokens=self.max_completion_length,
+             do_sample=True,
+             temperature=1,  # HACK
+             num_return_sequences=self.num_generations,
+             pad_token_id=pad_token_id,
+         )
+         self.beta = args.beta
+
+         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+         # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+         # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
+         # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+         # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+         # This acts as a flag to indicate that the warning has already been issued.
+         model.warnings_issued["estimate_tokens"] = True
+
+         # Initialize the metrics
+         self._metrics = defaultdict(list)
+
+         super().__init__(
+             model=model,
+             args=args,
+             data_collator=data_collator,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             processing_class=processing_class,
+             callbacks=callbacks,
+             optimizers=optimizers,
+         )
+
+         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether
+         # the model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+         # self.model_accepts_loss_kwargs to False to enable scaling.
+         self.model_accepts_loss_kwargs = False
+
+         if self.ref_model is not None:
+             if self.is_deepspeed_enabled:
+                 self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+             else:
+                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+         for i, reward_func in enumerate(self.reward_funcs):
+             if isinstance(reward_func, PreTrainedModel):
+                 self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+     def _set_signature_columns_if_needed(self):
+         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+         # By default, this method sets `self._signature_columns` to the model's expected inputs.
+         # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+         # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+         if self._signature_columns is None:
+             self._signature_columns = ["prompt"]
+
+     # Get the per-token log probabilities for the completions for the model and the reference model
+     def _get_per_token_logps(self, model, input_ids, **kwargs):
+         logits = model(input_ids, **kwargs).logits  # (B, L, V)
+         logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+         input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
+         # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+         per_token_logps = []
+         for logits_row, input_ids_row in zip(logits, input_ids):
+             log_probs = logits_row.log_softmax(dim=-1)
+             token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+             per_token_logps.append(token_log_prob)
+         return torch.stack(per_token_logps)
+
+     # Trainer "prepares" the inputs before calling `compute_loss`: it converts them to tensors and moves them to the
+     # device. Since we preprocess the data in `compute_loss`, we override this method to skip that step.
+     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+         return inputs
+
+     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+         if return_outputs:
+             raise ValueError("The GRPOTrainer does not support returning outputs")
+
+         prompts = [x["prompt"] for x in inputs]
+         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+         if "image" in inputs[0]:
+             images = [x["image"] for x in inputs]
+         elif "video_filename" in inputs[0]:
+             video_inputs = []
+             for cur_idx, cur_input in enumerate(inputs):
+                 copy_input = cur_input.copy()
+                 # Resolve the relative video filename against the local data directory.
+                 copy_input["prompt"][0]["content"][0]["video"] = (
+                     os.getcwd() + "/data" + inputs[cur_idx]["video_filename"][1:]
+                 )
+                 video_inputs.append(process_vision_info(copy_input["prompt"])[1])
+
+         prompt_inputs = self.processing_class(
+             text=prompts_text,
+             images=images if "image" in inputs[0] else None,
+             videos=video_inputs if "video_filename" in inputs[0] else None,
+             return_tensors="pt",
+             padding=True,
+             padding_side="left",
+             add_special_tokens=False,
+         )
+         prompt_inputs = super()._prepare_inputs(prompt_inputs)
+
+         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+         if self.max_prompt_length is not None:
+             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+         # Generate completions
+         with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+             # Generate `num_generations` times, one sequence per call, then right-pad the shorter outputs with the
+             # tokenizer's pad token so they can be stacked.
+             num_generations = self.generation_config.num_return_sequences
+             temp_generation_config = copy.deepcopy(self.generation_config)
+             temp_generation_config.num_return_sequences = 1
+
+             all_completions = []
+             for i in range(num_generations):
+                 completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
+                 all_completions.append(completion)
+
+             # Stack all completions and pad if needed
+             max_length = max(completion.size(1) for completion in all_completions)
+             padded_completions = []
+             for completion in all_completions:
+                 if completion.size(1) < max_length:
+                     padding = torch.full(
+                         (completion.size(0), max_length - completion.size(1)),
+                         self.processing_class.tokenizer.pad_token_id,
+                         dtype=completion.dtype,
+                         device=completion.device,
+                     )
+                     padded_completion = torch.cat([completion, padding], dim=1)
+                 else:
+                     padded_completion = completion
+                 padded_completions.append(padded_completion)
+
+             # Stack all padded completions
+             prompt_completion_ids = torch.cat(padded_completions, dim=0)
+
+         prompt_length = prompt_inputs["input_ids"].size(1)
+         completion_ids = prompt_completion_ids[:, prompt_length:]
+
+         # Mask everything after the first EOS token
+         is_eos = completion_ids == self.processing_class.eos_token_id
+         device = self.accelerator.device
+         eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+         eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+         sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+         completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+         prompt_inputs.pop("input_ids")
+         prompt_inputs.pop("attention_mask")
+         # The inputs are assumed to come from a Qwen2-VL-style processor; repeat the visual inputs once per
+         # completion so they line up with `prompt_completion_ids`.
+         if "image" in inputs[0]:
+             prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
+             prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
+         if "video_filename" in inputs[0]:
+             prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
+             prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
+
+         per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+         # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+         per_token_logps = per_token_logps[:, prompt_length - 1 :]
+
+         with torch.inference_mode():
+             if self.ref_model is not None:
+                 ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids)
+             else:
+                 with self.accelerator.unwrap_model(model).disable_adapter():
+                     ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+         ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+
+         # Compute the KL divergence between the model and the reference model
+         per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+         per_token_kl = torch.clamp(per_token_kl, min=-100, max=100)
+
+         # Decode the generated completions
+         completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+         if is_conversational(inputs[0]):
+             completions = [[{"role": "assistant", "content": completion}] for completion in completions]
+
+         # Compute the rewards
+         prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+
+         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+         for i, (reward_func, reward_processing_class) in enumerate(
+             zip(self.reward_funcs, self.reward_processing_classes)
+         ):
+             if isinstance(reward_func, PreTrainedModel):
+                 if is_conversational(inputs[0]):
+                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                     texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                 else:
+                     texts = [p + c for p, c in zip(prompts, completions)]
+                 reward_inputs = reward_processing_class(
+                     texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                 )
+                 reward_inputs = super()._prepare_inputs(reward_inputs)
+                 with torch.inference_mode():
+                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+             else:
+                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                 reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+                 for key in reward_kwargs:
+                     for example in inputs:
+                         # Repeat each value in the column for `num_generations` times
+                         reward_kwargs[key].extend([example[key]] * self.num_generations)
+                 output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+         # Sum the rewards from all reward functions
+         rewards = rewards_per_func.sum(dim=1)
+
+         # Compute grouped-wise rewards
+         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+         std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+         # Normalize the rewards to compute the advantages
+         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+         # x - x.detach() allows for preserving gradients from x
+         per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+         per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+         # Log the metrics
+         completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+         self._metrics["completion_length"].append(completion_length)
+
+         reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+         for i, reward_func in enumerate(self.reward_funcs):
+             if isinstance(reward_func, PreTrainedModel):
+                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+             else:
+                 reward_func_name = reward_func.__name__
+             self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+         self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+         self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+         mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+         self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+         return loss
+
+     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+         metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+         logs = {**logs, **metrics}
+         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+             super().log(logs, start_time)
+         else:  # transformers<=4.46
+             super().log(logs)
+         self._metrics.clear()
+
+     def create_model_card(
+         self,
+         model_name: Optional[str] = None,
+         dataset_name: Optional[str] = None,
+         tags: Union[str, list[str], None] = None,
+     ):
+         """
+         Creates a draft of a model card using the information available to the `Trainer`.
+
+         Args:
+             model_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the model.
+             dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the dataset used for training.
+             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                 Tags to be associated with the model card.
+         """
+         if not self.is_world_process_zero():
+             return
+
+         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+             base_model = self.model.config._name_or_path
+         else:
+             base_model = None
+
+         tags = tags or []
+         if isinstance(tags, str):
+             tags = [tags]
+
+         if hasattr(self.model.config, "unsloth_version"):
+             tags.append("unsloth")
+
+         citation = textwrap.dedent(
+             """\
+             @article{zhihong2024deepseekmath,
+                 title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
+                 author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
+                 year         = 2024,
+                 eprint       = {arXiv:2402.03300},
+             }
+             """
+         )
+
+         model_card = generate_model_card(
+             base_model=base_model,
+             model_name=model_name,
+             hub_model_id=self.hub_model_id,
+             dataset_name=dataset_name,
+             tags=tags,
+             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+             comet_url=get_comet_experiment_url(),
+             trainer_name="GRPO",
+             trainer_citation=citation,
+             paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
+             paper_id="2402.03300",
+         )
+
+         model_card.save(os.path.join(self.args.output_dir, "README.md"))
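
The core of `compute_loss` in the diff above is the group-relative advantage: each reward is normalized against the mean and standard deviation of its generation group. A minimal, self-contained sketch of that step is below; the function name `grpo_advantages` is ours, not part of the trainer, and the reward values are illustrative.

```python
import torch


def grpo_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    """Normalize each reward against the mean/std of its generation group.

    `rewards` has shape (B * G,) where G = num_generations, laid out so that the
    G completions of each prompt are contiguous, as in the trainer.
    """
    grouped = rewards.view(-1, num_generations)  # (B, G)
    mean = grouped.mean(dim=1).repeat_interleave(num_generations)  # (B*G,)
    std = grouped.std(dim=1).repeat_interleave(num_generations)    # (B*G,)
    # eps mirrors the trainer's 1e-4 guard against zero-variance groups.
    return (rewards - mean) / (std + eps)


# Two prompts, two generations each: group 1 has mixed rewards, group 2 is uniform.
rewards = torch.tensor([1.0, 0.0, 1.0, 1.0])
adv = grpo_advantages(rewards, num_generations=2)
```

Note the design consequence visible here: a group whose completions all score the same (the second pair) contributes zero advantage, so GRPO learns nothing from prompts where every sample succeeds or every sample fails.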
previous_version/Video-R1-main-previous/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer.py ADDED
@@ -0,0 +1,832 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import textwrap
17
+ from collections import defaultdict
18
+ from typing import Any, Callable, Optional, Union
19
+ from accelerate.utils.other import is_compiled_module
20
+ from accelerate.utils import broadcast_object_list, gather, gather_object
21
+ import torch
22
+ import torch.utils.data
23
+ import transformers
24
+ import warnings
25
+ from unittest.mock import patch
26
+ from datasets import Dataset, IterableDataset
27
+ from packaging import version
28
+ from transformers import (
29
+ AriaForConditionalGeneration,
30
+ AriaProcessor,
31
+ AutoModelForCausalLM,
32
+ AutoModelForSequenceClassification,
33
+ AutoProcessor,
34
+ AutoTokenizer,
35
+ GenerationConfig,
36
+ PreTrainedModel,
37
+ PreTrainedTokenizerBase,
38
+ Qwen2VLForConditionalGeneration,
39
+ Trainer,
40
+ TrainerCallback,
41
+ is_wandb_available,
42
+ )
43
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
44
+ from transformers.utils import is_peft_available
45
+
46
+ from trl.data_utils import (
47
+ apply_chat_template,
48
+ is_conversational,
49
+ maybe_apply_chat_template,
50
+ )
51
+ from trl.import_utils import is_vllm_available
52
+
53
+ from trl.models import (
54
+ create_reference_model,
55
+ prepare_deepspeed,
56
+ unwrap_model_for_generation,
57
+ )
58
+ from trl.trainer.grpo_config import GRPOConfig
59
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
60
+ from trl import GRPOTrainer
61
+
62
+ import copy
63
+
64
+ if is_peft_available():
65
+ from peft import PeftConfig, get_peft_model
66
+
67
+ if is_vllm_available():
68
+ from vllm import LLM, SamplingParams
69
+
70
+
71
+ if is_wandb_available():
72
+ import wandb
73
+ import torch.nn as nn
74
+ from torch.utils.data import Sampler
75
+
76
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
77
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
78
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
79
+
80
+
81
+ class RepeatRandomSampler(Sampler):
82
+ """
83
+ Sampler that repeats the indices of a dataset N times.
84
+
85
+ Args:
86
+ data_source (`Sized`):
87
+ Dataset to sample from.
88
+ repeat_count (`int`):
89
+ Number of times to repeat each index.
90
+
91
+ Example:
92
+ ```python
93
+ >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
94
+ >>> list(sampler)
95
+ [2, 2, 0, 0, 3, 3, 1, 1]
96
+ ```
97
+ """
98
+
99
+ def __init__(self, data_source, repeat_count: int):
100
+ self.data_source = data_source
101
+ self.repeat_count = repeat_count
102
+ self.num_samples = len(data_source)
103
+
104
+ def __iter__(self):
105
+ indexes = [
106
+ idx
107
+ for idx in torch.randperm(self.num_samples).tolist()
108
+ for _ in range(self.repeat_count)
109
+ ]
110
+ return iter(indexes)
111
+
112
+ def __len__(self):
113
+ return self.num_samples * self.repeat_count
114
+
115
+
116
+ class Qwen2VLGRPOVLLMTrainer(Trainer):
117
+ def __init__(
118
+ self,
119
+ model: Union[str, PreTrainedModel],
120
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
121
+ args: GRPOConfig = None,
122
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
123
+ eval_dataset: Optional[
124
+ Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
125
+ ] = None,
126
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
127
+ reward_processing_classes: Optional[
128
+ Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
129
+ ] = None,
130
+ callbacks: Optional[list[TrainerCallback]] = None,
131
+ optimizers: tuple[
132
+ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
133
+ ] = (None, None),
134
+ peft_config: Optional["PeftConfig"] = None,
135
+ # qwen2-vl related params
136
+ max_pixels: Optional[int] = 12845056,
137
+ min_pixels: Optional[int] = 3136,
138
+ attn_implementation: str = "flash_attention_2",
139
+ ):
140
+
141
+ # Args
142
+ if args is None:
143
+ model_name = model if isinstance(model, str) else model.config._name_or_path
144
+ model_name = model_name.split("/")[-1]
145
+ args = GRPOConfig(f"{model_name}-GRPO")
146
+
147
+ # Models
148
+ # Trained model
149
+ model_init_kwargs = args.model_init_kwargs or {}
150
+ model_init_kwargs["attn_implementation"] = attn_implementation
151
+ if isinstance(model, str):
152
+ model_id = model
153
+ torch_dtype = model_init_kwargs.get("torch_dtype")
154
+ if (
155
+ isinstance(torch_dtype, torch.dtype)
156
+ or torch_dtype == "auto"
157
+ or torch_dtype is None
158
+ ):
159
+ pass # torch_dtype is already a torch.dtype or "auto" or None
160
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
161
+ torch_dtype = getattr(torch, torch_dtype)
162
+ model_init_kwargs["torch_dtype"] = torch_dtype
163
+ else:
164
+ raise ValueError(
165
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
166
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
167
+ )
168
+ # Disable caching if gradient checkpointing is enabled (not supported)
169
+ model_init_kwargs["use_cache"] = (
170
+ False
171
+ if args.gradient_checkpointing
172
+ else model_init_kwargs.get("use_cache")
173
+ )
174
+ if "Qwen2-VL" in model_id:
175
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
176
+ model, **model_init_kwargs
177
+ )
178
+ elif "Aria" in model_id:
179
+ model_init_kwargs.pop("use_cache")
180
+ model = AriaForConditionalGeneration.from_pretrained(
181
+ model, **model_init_kwargs
182
+ )
183
+ else:
184
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
185
+ else:
186
+ model_id = model.config._name_or_path
187
+ if args.model_init_kwargs is not None:
188
+ raise ValueError(
189
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
190
+ "This argument can only be used when the `model` argument is a string."
191
+ )
192
+
193
+ if peft_config is not None:
194
+ model = get_peft_model(model, peft_config)
195
+
196
+ # Reference model
197
+ if is_deepspeed_zero3_enabled():
198
+ if "Qwen2-VL" in model_id:
199
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
200
+ model_id, **model_init_kwargs
201
+ )
202
+ elif "Aria" in model_id:
203
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(
204
+ model_id, **model_init_kwargs
205
+ )
206
+ else:
207
+ self.ref_model = AutoModelForCausalLM.from_pretrained(
208
+ model_id, **model_init_kwargs
209
+ )
210
+ elif peft_config is None:
211
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
212
+ self.ref_model = create_reference_model(model)
213
+ else:
214
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
215
+ # to revert to the initial model.
216
+ self.ref_model = None
217
+
218
+ # Processing class
219
+ if processing_class is None:
220
+ if "Qwen2-VL" in model_id or "Aria" in model_id:
221
+ processing_class = AutoProcessor.from_pretrained(model_id)
222
+ pad_token_id = processing_class.tokenizer.pad_token_id
223
+ processing_class.pad_token_id = pad_token_id
224
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
225
+ if "Qwen2-VL" in model_id:
226
+ processing_class.image_processor.max_pixels = max_pixels
227
+ processing_class.image_processor.min_pixels = min_pixels
228
+ else:
229
+ processing_class = AutoTokenizer.from_pretrained(
230
+ model.config._name_or_path, padding_side="left"
231
+ )
232
+ pad_token_id = processing_class.pad_token_id
233
+
234
+ # Reward functions
235
+ if not isinstance(reward_funcs, list):
236
+ reward_funcs = [reward_funcs]
237
+ for i, reward_func in enumerate(reward_funcs):
238
+ if isinstance(reward_func, str):
239
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
240
+ reward_func, num_labels=1, **model_init_kwargs
241
+ )
242
+ self.reward_funcs = reward_funcs
243
+
244
+ # Reward processing class
245
+ if reward_processing_classes is None:
246
+ reward_processing_classes = [None] * len(reward_funcs)
247
+ elif not isinstance(reward_processing_classes, list):
248
+ reward_processing_classes = [reward_processing_classes]
249
+ else:
250
+ if len(reward_processing_classes) != len(reward_funcs):
251
+ raise ValueError(
252
+ "The number of reward processing classes must match the number of reward functions."
253
+ )
254
+
255
+ for i, (reward_processing_class, reward_func) in enumerate(
256
+ zip(reward_processing_classes, reward_funcs)
257
+ ):
258
+ if isinstance(reward_func, PreTrainedModel):
259
+ if reward_processing_class is None:
260
+ reward_processing_class = AutoTokenizer.from_pretrained(
261
+ reward_func.config._name_or_path
262
+ )
263
+ if reward_processing_class.pad_token_id is None:
264
+ reward_processing_class.pad_token = (
265
+ reward_processing_class.eos_token
266
+ )
267
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
268
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
269
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
270
+ reward_processing_classes[i] = reward_processing_class
271
+ self.reward_processing_classes = reward_processing_classes
272
+
273
+ # Data collator
274
+ def data_collator(features): # No data collation is needed in GRPO
275
+ return features
276
+
277
+ # Training arguments
278
+ self.max_prompt_length = args.max_prompt_length
279
+ self.max_completion_length = (
280
+ args.max_completion_length
281
+ ) # = |o_i| in the GRPO paper
282
+ self.num_generations = args.num_generations # = G in the GRPO paper
283
+ self.generation_config = GenerationConfig(
284
+ max_new_tokens=self.max_completion_length,
285
+ do_sample=True,
286
+ temperature=1, # HACK
287
+ num_return_sequences=self.num_generations,
288
+ pad_token_id=pad_token_id,
289
+ )
290
+ self.beta = args.beta
291
+
292
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
293
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
294
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
295
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
296
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
297
+ # This acts as a flag to indicate that the warning has already been issued.
298
+ model.warnings_issued["estimate_tokens"] = True
299
+
300
+ # Initialize the metrics
301
+ self._metrics = defaultdict(list)
302
+ self.use_vllm = args.use_vllm
303
+
304
+ # rewrite the processing AutoTokenizer -> AutoProcessor
305
+ model_id = model if isinstance(model, str) else model.config._name_or_path
306
+ if processing_class is None:
307
+ if "Qwen2-VL" in model_id or "Aria" in model_id:
308
+ processing_class = AutoProcessor.from_pretrained(model_id)
309
+ pad_token_id = processing_class.tokenizer.pad_token_id
310
+ processing_class.pad_token_id = pad_token_id
311
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
312
+ if "Qwen2-VL" in model_id:
313
+ processing_class.image_processor.max_pixels = max_pixels
314
+ processing_class.image_processor.min_pixels = min_pixels
315
+ else:
316
+ processing_class = AutoTokenizer.from_pretrained(
317
+ model.config._name_or_path, padding_side="left"
318
+ )
319
+ pad_token_id = processing_class.pad_token_id
320
+
321
+ super().__init__(
322
+ model=model,
323
+ args=args,
324
+ data_collator=data_collator,
325
+ train_dataset=train_dataset,
326
+ eval_dataset=eval_dataset,
327
+ processing_class=processing_class,
328
+ callbacks=callbacks,
329
+ optimizers=optimizers,
330
+ )
331
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
332
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
333
+ # self.model_accepts_loss_kwargs to False to enable scaling.
334
+ self.model_accepts_loss_kwargs = False
335
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
336
+ num_processes = self.accelerator.num_processes
337
+ global_batch_size = args.per_device_train_batch_size * num_processes
338
+ possible_values = [
339
+ n_gen
340
+ for n_gen in range(2, global_batch_size + 1)
341
+ if (global_batch_size) % n_gen == 0
342
+ ]
343
+
344
+ if self.num_generations not in possible_values:
345
+ raise ValueError(
346
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
347
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
348
+ f"batch size, the valid values for the number of generations are: {possible_values}."
349
+ )
350
+ if self.args.eval_strategy != "no":
351
+ global_batch_size = args.per_device_eval_batch_size * num_processes
352
+ possible_values = [
353
+ n_gen
354
+ for n_gen in range(2, global_batch_size + 1)
355
+ if (global_batch_size) % n_gen == 0
356
+ ]
357
+ if self.num_generations not in possible_values:
358
+ raise ValueError(
359
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
360
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
361
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
362
+ )
363
+
364
+ if self.use_vllm:
365
+ if not is_vllm_available():
366
+ raise ImportError(
367
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
368
+ "`pip install vllm` to use it."
369
+ )
370
+
371
+ if self.accelerator.is_main_process:
372
+ vllm_device = self.args.vllm_device
373
+ if vllm_device == "auto":
374
+ vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
375
+ # Check that the requested device is available
376
+ if (
377
+ vllm_device.split(":")[0] == "cuda"
378
+ and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
379
+ ):
380
+ raise ValueError(
381
+ f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
382
+ "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
383
+ "value lower than the number of GPUs available on your machine—typically, reducing it by one "
384
+ f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
385
+ )
386
+ # Check that the requested device is not also used for training
387
+ if vllm_device in {
388
+ f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
389
+ }:
390
+ warnings.warn(
391
+ f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
392
+ "behavior. It is recommended to use a dedicated device for vLLM."
393
+ )
394
+ # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
395
+ # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
396
+ # setting (profiling_patch).
397
+ world_size_patch = patch(
398
+ "torch.distributed.get_world_size", return_value=1
399
+ )
400
+ profiling_patch = patch(
401
+ "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
402
+ return_value=None,
403
+ )
404
+ with world_size_patch, profiling_patch:
405
+ print("vllm is running on: ", vllm_device)
406
+ self.llm = LLM(
407
+ model=model.name_or_path,
408
+ device=vllm_device,
409
+ gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
410
+ dtype=torch.bfloat16,
411
+ # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
412
+ # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
413
+ # This is particularly useful here because we generate completions from the same prompts.
414
+ enable_prefix_caching=True,
415
+ enforce_eager=True,
416
+ # Ensure that training and inference use the same processor for images.
417
+ mm_processor_kwargs=(
418
+ {
419
+ "max_pixels": max_pixels,
420
+ "min_pixels": min_pixels,
421
+ }
422
+ if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
423
+ else None
424
+ ),
425
+ max_model_len=args.max_completion_length,
426
+ )
427
+ self.sampling_params = SamplingParams(
428
+ temperature=args.temperature,
429
+ max_tokens=self.max_completion_length,
430
+ )
431
+
432
+ self._last_loaded_step = (
433
+ 0 # tag to avoid useless loading during grad accumulation
434
+ )
435
+
436
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
437
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
438
+ # synchronize all processes after vLLM has been fully initialized.
439
+ self.accelerator.wait_for_everyone()
440
+ else:
441
+ raise ValueError(
442
+ "Qwen2VLGRPOVLLMTrainer only supports vllm generation, please set --use_vllm True"
443
+ )
444
+
445
+ if self.ref_model is not None:
446
+ if self.is_deepspeed_enabled:
447
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
448
+ else:
449
+ self.ref_model = self.accelerator.prepare_model(
450
+ self.ref_model, evaluation_mode=True
451
+ )
452
+
453
+ for i, reward_func in enumerate(self.reward_funcs):
454
+ if isinstance(reward_func, PreTrainedModel):
455
+ self.reward_funcs[i] = self.accelerator.prepare_model(
456
+ reward_func, evaluation_mode=True
457
+ )
458
+
459
+ def _set_signature_columns_if_needed(self):
460
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
461
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
462
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
463
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
464
+ if self._signature_columns is None:
465
+ self._signature_columns = ["prompt"]
466
+
467
+ # We need a custom sampler that samples the same prompt multiple times
468
+ def _get_train_sampler(self):
469
+ return RepeatRandomSampler(self.train_dataset, self.num_generations)
470
+
471
+ # Get the per-token log probabilities for the completions for the model and the reference model
472
+ def _get_per_token_logps(
473
+ self,
474
+ model,
475
+ input_ids,
476
+ attention_mask,
477
+ pixel_values,
478
+ image_grid_thw,
479
+ logits_to_keep,
480
+ ):
481
+ pixel_values = pixel_values.to(model.device)
482
+ image_grid_thw = image_grid_thw.to(device=model.device)
483
+ logits = model(
484
+ input_ids,
485
+ attention_mask=attention_mask,
486
+ pixel_values=pixel_values,
487
+ image_grid_thw=image_grid_thw,
488
+ ).logits # (B, L, V)
489
+ logits = logits[
490
+ :, :-1, :
491
+ ] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
492
+ input_ids = input_ids[
493
+ :, -logits_to_keep:
494
+ ] # (B, L-1), exclude the first input ID since we don't have logits for it
495
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
496
+ logits = logits[:, -logits_to_keep:]
497
+ per_token_logps = []
498
+ for logits_row, input_ids_row in zip(logits, input_ids):
499
+ log_probs = logits_row.log_softmax(dim=-1)
500
+ token_log_prob = torch.gather(
501
+ log_probs, dim=1, index=input_ids_row.unsqueeze(1)
502
+ ).squeeze(1)
503
+ per_token_logps.append(token_log_prob)
504
+ return torch.stack(per_token_logps)
505
+
506
+ # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
507
+ # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
508
+ def _prepare_inputs(
509
+ self, inputs: dict[str, Union[torch.Tensor, Any]]
510
+ ) -> dict[str, Union[torch.Tensor, Any]]:
511
+ device = self.accelerator.device
512
+ prompts = [x["prompt"] for x in inputs]
513
+ images = [x["image"] for x in inputs]
514
+ prompts_text = [
515
+ maybe_apply_chat_template(example, self.processing_class)["prompt"]
516
+ for example in inputs
517
+ ]
518
+ prompt_inputs = self.processing_class(
519
+ # prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
520
+ text=prompts_text,
521
+ images=images,
522
+ return_tensors="pt",
523
+ padding=True,
524
+ padding_side="left",
525
+ add_special_tokens=False,
526
+ )
527
+ prompt_ids, prompt_mask = (
528
+ prompt_inputs["input_ids"].to(device),
529
+ prompt_inputs["attention_mask"].to(device),
530
+ )
531
+ if self.max_prompt_length is not None:
532
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
533
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
534
+
535
+ if self.args.use_vllm:
536
+ # First, have main process load weights if needed
537
+ if self.state.global_step != self._last_loaded_step:
538
+ with unwrap_model_for_generation(
539
+ self.model,
540
+ self.accelerator,
541
+ gather_deepspeed3_params=False, # TODO: fix this, self.args.ds3_gather_for_generation,
542
+ ) as unwrapped_model:
543
+ if is_compiled_module(unwrapped_model):
544
+ state_dict = unwrapped_model._orig_mod.state_dict()
545
+ else:
546
+ state_dict = unwrapped_model.state_dict()
547
+ if self.accelerator.is_main_process:
548
+ llm_model = (
549
+ self.llm.llm_engine.model_executor.driver_worker.model_runner.model
550
+ )
551
+ llm_model.load_weights(state_dict.items())
552
+ self._last_loaded_step = self.state.global_step
553
+
554
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
555
+ all_prompts_text = gather_object(prompts_text)
556
+ all_images = gather_object(images)
557
+ # group into pairs
558
+ all_multimodal_inputs = [
559
+ {"prompt": p, "multi_modal_data": {"image": i}}
560
+ for p, i in zip(all_prompts_text, all_images)
561
+ ]
562
+
563
+ if self.accelerator.is_main_process:
564
+ outputs = self.llm.generate(
565
+ all_multimodal_inputs,
566
+ sampling_params=self.sampling_params,
567
+ use_tqdm=False,
568
+ )
569
+ completion_ids = [
570
+ out.token_ids
571
+ for completions in outputs
572
+ for out in completions.outputs
573
+ ]
574
+ else:
575
+ completion_ids = [None] * len(all_prompts_text)
576
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
577
+ process_slice = slice(
578
+ self.accelerator.process_index * len(prompts),
579
+ (self.accelerator.process_index + 1) * len(prompts),
580
+ )
581
+ completion_ids = completion_ids[process_slice]
582
+
583
+ # Pad the completions, and concatenate them with the prompts
584
+ completion_ids = [
585
+ torch.tensor(ids, device=device) for ids in completion_ids
586
+ ]
587
+ completion_ids = pad(
588
+ completion_ids, padding_value=self.processing_class.pad_token_id
589
+ )
590
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
591
+ else:
592
+ raise ValueError("Only vLLM generation is supported in this version ")
593
+
594
+ # below are the same with yifan's code
595
+ # Mask everything after the first EOS token
596
+ is_eos = completion_ids == self.processing_class.eos_token_id
597
+ device = self.accelerator.device
598
+ eos_idx = torch.full(
599
+ (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
600
+ )
601
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
602
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
603
+ is_eos.size(0), -1
604
+ )
605
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
606
+
607
+ # Concatenate prompt_mask with completion_mask for logit computation
608
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
609
+ # pixel_values = prompt_inputs["pixel_values"].repeat_interleave(
610
+ # self.num_generations, dim=0
611
+ # )
612
+
613
+ pixel_values = prompt_inputs["pixel_values"]
614
+ # [None].repeat_interleave(self.num_generations, dim=0)
615
+ # pixel_values = pixel_values.view(-1, pixel_values.shape[-1])
616
+
617
+ image_grid_thw = prompt_inputs["image_grid_thw"]
618
+ # .repeat_interleave(
619
+ # self.num_generations, dim=0
620
+ # )
621
+ logits_to_keep = completion_ids.size(1)
622
+
623
+ with torch.inference_mode():
624
+ if self.ref_model is not None:
625
+ ref_per_token_logps = self._get_per_token_logps(
626
+ self.ref_model,
627
+ prompt_completion_ids,
628
+ attention_mask,
629
+ pixel_values,
630
+ image_grid_thw,
631
+ logits_to_keep,
632
+ )
633
+ else:
634
+ with self.accelerator.unwrap_model(self.model).disable_adapter():
635
+ ref_per_token_logps = self._get_per_token_logps(
636
+ self.model,
637
+ prompt_completion_ids,
638
+ attention_mask,
639
+ pixel_values,
640
+ image_grid_thw,
641
+ logits_to_keep,
642
+ )
643
+
644
+ # Decode the generated completions
645
+ completions = self.processing_class.batch_decode(
646
+ completion_ids, skip_special_tokens=True
647
+ )
648
+ if is_conversational(inputs[0]):
649
+ completions = [
650
+ [{"role": "assistant", "content": completion}]
651
+ for completion in completions
652
+ ]
653
+
654
+ # Compute the rewards
655
+ rewards_per_func = torch.zeros(
656
+ len(prompts), len(self.reward_funcs), device=device
657
+ )
658
+ for i, (reward_func, reward_processing_class) in enumerate(
659
+ zip(self.reward_funcs, self.reward_processing_classes)
660
+ ):
661
+ if isinstance(reward_func, PreTrainedModel):
662
+ if is_conversational(inputs[0]):
663
+ messages = [
664
+ {"messages": p + c} for p, c in zip(prompts, completions)
665
+ ]
666
+ texts = [
667
+ apply_chat_template(x, reward_processing_class)["text"]
668
+ for x in messages
669
+ ]
670
+ else:
671
+ texts = [p + c for p, c in zip(prompts, completions)]
672
+ reward_inputs = reward_processing_class(
673
+ texts,
674
+ return_tensors="pt",
675
+ padding=True,
676
+ padding_side="right",
677
+ add_special_tokens=False,
678
+ )
679
+ reward_inputs = super()._prepare_inputs(reward_inputs)
680
+ with torch.inference_mode():
681
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
682
+ :, 0
683
+ ] # Shape (B*G,)
684
+ else:
685
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
686
+ reward_kwargs = {
687
+ key: []
688
+ for key in inputs[0].keys()
689
+ if key not in ["prompt", "completion"]
690
+ }
691
+ for key in reward_kwargs:
692
+ for example in inputs:
693
+ # Repeat each value in the column for `num_generations` times
694
+ reward_kwargs[key].extend([example[key]] * self.num_generations)
695
+ output_reward_func = reward_func(
696
+ prompts=prompts, completions=completions, **reward_kwargs
697
+ )
698
+ rewards_per_func[:, i] = torch.tensor(
699
+ output_reward_func, dtype=torch.float32, device=device
700
+ )
701
+ rewards_per_func = gather(rewards_per_func)
702
+ # Sum the rewards from all reward functions
703
+ rewards = rewards_per_func.sum(dim=1)
704
+
705
+ # Compute grouped-wise rewards
706
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
707
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
708
+
709
+ # Normalize the rewards to compute the advantages
710
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
711
+ self.num_generations, dim=0
712
+ )
713
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(
714
+ self.num_generations, dim=0
715
+ )
716
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
717
+
718
+ # Slice to keep only the local part of the data
719
+ process_slice = slice(
720
+ self.accelerator.process_index * len(prompts),
721
+ (self.accelerator.process_index + 1) * len(prompts),
722
+ )
723
+ advantages = advantages[process_slice]
724
+
725
+ # Log the metrics
726
+ reward_per_func = rewards_per_func.mean(0)
727
+ for i, reward_func in enumerate(self.reward_funcs):
728
+ if isinstance(
729
+ reward_func, nn.Module
730
+ ): # Module instead of PretrainedModel for compat with compiled models
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+ else:
+ reward_func_name = reward_func.__name__
+ self._metrics[f"rewards/{reward_func_name}"].append(
+ reward_per_func[i].item()
+ )
+
+ self._metrics["reward"].append(rewards.mean().item())
+ self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
+
+ return {
+ "prompt_ids": prompt_ids,
+ "prompt_mask": prompt_mask,
+ "completion_ids": completion_ids,
+ "completion_mask": completion_mask,
+ "ref_per_token_logps": ref_per_token_logps,
+ "advantages": advantages,
+ "pixel_values": pixel_values,
+ "image_grid_thw": image_grid_thw,
+ }
+
+ def compute_loss(
+ self, model, inputs, return_outputs=False, num_items_in_batch=None
+ ):
+ if return_outputs:
+ raise ValueError("The GRPOTrainer does not support returning outputs")
+ # Compute the per-token log probabilities for the model
+
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+ completion_ids, completion_mask = (
+ inputs["completion_ids"],
+ inputs["completion_mask"],
+ )
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+ pixel_values = inputs["pixel_values"]
+ image_grid_thw = inputs["image_grid_thw"]
+ logits_to_keep = completion_ids.size(
+ 1
+ ) # we only need to compute the logits for the completion tokens
+
+ per_token_logps = self._get_per_token_logps(
+ model,
+ input_ids,
+ attention_mask,
+ pixel_values,
+ image_grid_thw,
+ logits_to_keep,
+ )
+
+ # Compute the KL divergence between the model and the reference model
+ ref_per_token_logps = inputs["ref_per_token_logps"]
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps)
+ - (ref_per_token_logps - per_token_logps)
+ - 1
+ )
+
+ # x - x.detach() allows for preserving gradients from x
+ advantages = inputs["advantages"]
+ per_token_loss = torch.exp(
+ per_token_logps - per_token_logps.detach()
+ ) * advantages.unsqueeze(1)
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+ loss = (
+ (per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
+ ).mean()
+
+ # Log the metrics
+ completion_length = (
+ self.accelerator.gather_for_metrics(completion_mask.sum(1))
+ .float()
+ .mean()
+ .item()
+ )
+ self._metrics["completion_length"].append(completion_length)
+
+ mean_kl = (
+ (per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
+ ).mean()
+ self._metrics["kl"].append(
+ self.accelerator.gather_for_metrics(mean_kl).mean().item()
+ )
+
+ return loss
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ metrics = {
+ key: sum(val) / len(val) for key, val in self._metrics.items()
+ } # average the metrics
+
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+ if next(iter(logs.keys())).startswith("eval_"):
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+ logs = {**logs, **metrics}
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ super().log(logs, start_time)
+ else: # transformers<=4.46
+ super().log(logs)
+ self._metrics.clear()
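The per-token KL term above is the so-called k3 estimator, exp(Δ) − Δ − 1 with Δ = ref_logp − logp, which is non-negative (since exp(Δ) ≥ 1 + Δ) and zero exactly when the two policies agree on the token. A minimal dependency-free sketch of the scalar arithmetic (plain Python stand-ins for the torch tensors above; illustrative only, not the trainer's API):

```python
import math

def k3_kl(ref_logp, logp):
    # k3 estimator: exp(d) - d - 1 with d = ref_logp - logp.
    # Non-negative because exp(d) >= 1 + d, and 0 when the policy
    # matches the reference on this token.
    d = ref_logp - logp
    return math.exp(d) - d - 1.0

def grpo_token_loss(logp, advantage, beta, ref_logp):
    # At the sampling point exp(logp - logp.detach()) evaluates to 1,
    # so the loss *value* per token is -(advantage - beta * kl); in the
    # real torch code gradients still flow through logp.
    return -(advantage - beta * k3_kl(ref_logp, logp))
```

Note that because the importance ratio is 1 at the sampling point, the scalar loss value per completion reduces to minus the mean advantage plus the KL penalty, which is why the logged `loss` can look flat while gradients are still informative.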
setup.sh ADDED
@@ -0,0 +1,19 @@
+ # Install the packages in r1-v.
+ cd src/r1-v
+ pip install -e ".[dev]"
+
+ # Additional modules
+ pip install wandb==0.18.3
+ pip install tensorboardx
+ pip install qwen_vl_utils torchvision
+ pip install flash-attn --no-build-isolation
+
+ # vLLM support
+ pip install vllm==0.7.2
+
+ pip install nltk
+ pip install rouge_score
+ pip install deepspeed
+
+ # fix transformers version
+ # pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef
src/download.py ADDED
@@ -0,0 +1,24 @@
+ from huggingface_hub import snapshot_download
+
+ # snapshot_download(
+ # repo_id="Video-R1/Video-R1-data",
+ # repo_type="dataset", # Specify it's a dataset repo
+ # local_dir="Video-R1-data", # Local directory to save data
+ # local_dir_use_symlinks=False # Set False if you want full file copies
+ # )
+
+
+ # snapshot_download(
+ # repo_id="OpenGVLab/MVBench",
+ # repo_type="dataset", # Specify it's a dataset repo
+ # local_dir="Evaluation/MVBench", # Local directory to save data
+ # local_dir_use_symlinks=False # Set False if you want full file copies
+ # )
+
+
+ snapshot_download(
+ repo_id="yale-nlp/MMVU",
+ repo_type="dataset", # Specify it's a dataset repo
+ local_dir="Evaluation/MMVU", # Local directory to save data
+ local_dir_use_symlinks=False # Set False if you want full file copies
+ )
src/eval_bench.py ADDED
@@ -0,0 +1,277 @@
+ import os
+ import json
+ import re
+ from tqdm import tqdm
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+ from rouge_score import rouge_scorer
+ import torch
+
+ from transformers import AutoProcessor, AutoTokenizer
+ from vllm import LLM, SamplingParams
+ from qwen_vl_utils import process_vision_info
+ import argparse
+
+
+ BSZ = 64
+
+
+ parser = argparse.ArgumentParser(description="Evaluation benchmark")
+ parser.add_argument('--model_path', type=str, required=True, help="Path to the model")
+ parser.add_argument('--file_name', type=str, required=True, help="Name of the file")
+ args = parser.parse_args()
+
+ MODEL_PATH = args.model_path
+ file_name = args.file_name
+
+
+
+ llm = LLM(
+ model=MODEL_PATH,
+ tensor_parallel_size=torch.cuda.device_count(),
+ # max_model_len = 8192 * 2,
+ max_model_len = 32768,
+ gpu_memory_utilization=0.75,
+ limit_mm_per_prompt={"image": 1, "video": 1},
+ )
+
+
+ sampling_params = SamplingParams(
+ temperature=0.1,
+ top_p=0.001,
+ max_tokens=1024,
+ stop_token_ids=[],
+ )
+
+
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ tokenizer.padding_side = "left"
+ processor.tokenizer = tokenizer
+
+
+ # for dataset_name in ['mvbench','tempcompass','videomme','videommmu','vsibench','mmvu']:
+ for dataset_name in ['mvbench', 'mmvu']:
+
+ OUTPUT_PATH = f"./src/r1-v/eval_results/eval_{dataset_name}_{file_name}_greedy_output.json"
+ PROMPT_PATH = f"./src/r1-v/Evaluation/eval_{dataset_name}.json"
+
+ data = [] # initialize before the .jsonl branch appends to it
+ if PROMPT_PATH.endswith('.jsonl'):
+ with open(PROMPT_PATH, "r", encoding="utf-8") as f:
+ for line in f:
+ data.append(json.loads(line))
+ elif PROMPT_PATH.endswith('.json'):
+ with open(PROMPT_PATH, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ else:
+ raise ValueError("Input file must be .json or .jsonl")
+
+ QUESTION_TEMPLATE = (
+ "{Question}\n"
+ "Please think about this question as if you were a human pondering deeply. "
+ "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
+ "It's encouraged to include self-reflection or verification in the reasoning process. "
+ "Provide your detailed reasoning between the <think> and </think> tags, and then give your final answer between the <answer> and </answer> tags."
+ )
+
+ TYPE_TEMPLATE = {
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
+ }
+
+
+ messages = []
+ for x in data:
+ if x["problem_type"] == 'multiple choice':
+ question = x['problem'] + "Options:\n"
+ for op in x["options"]:
+ question += op + "\n"
+ else:
+ question = x['problem']
+
+ msg = [{
+ "role": "user",
+ "content": [
+ {
+ "type": x['data_type'],
+ # x['data_type']: os.getcwd() + "/src/r1-v/Evaluation" + x['path'][1:]
+ x['data_type']: os.getcwd() + "/src/r1-v" + x['path'][1:]
+ },
+ {
+ "type": "text",
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[x['problem_type']]
+ }
+ ]
+ }]
+ messages.append(msg)
+
+
+ final_output = []
+ start_idx = 0
+ if os.path.exists(OUTPUT_PATH):
+ try:
+ with open(OUTPUT_PATH, "r", encoding="utf-8") as f:
+ existing = json.load(f)
+ final_output = existing.get("results", [])
+ start_idx = len(final_output)
+ print(f"Resuming from sample index {start_idx}")
+ except Exception as e:
+ print(f"Error reading existing output file: {e}")
+
+
+ def extract_think(output_str):
+ pattern = r'<think>\s*(.*?)\s*</think>'
+ match = re.search(pattern, output_str, re.DOTALL)
+ if match:
+ return match.group(1).strip()
+ return ""
+
+ def extract_answer(text):
+ pattern = r'<answer>\s*(.*?)\s*</answer>'
+ match = re.search(pattern, text, re.DOTALL)
+ if match:
+ return match.group(1).strip()
+ return ""
+
+ def normalize_number(num_str):
+ try:
+ num_str = num_str.replace(',', '')
+ return float(num_str)
+ except Exception as e:
+ return None
+
+ def mean_relative_accuracy(pred, target, start=0.5, end=0.95, interval=0.05):
+
+ if not torch.is_tensor(pred):
+ pred = torch.tensor(pred, dtype=torch.float32)
+ if not torch.is_tensor(target):
+ target = torch.tensor(target, dtype=torch.float32)
+
+ epsilon = 1e-8
+ rel_error = torch.abs(pred - target) / (torch.abs(target) + epsilon)
+
+ thresholds = torch.arange(start, end + interval/2, interval, dtype=torch.float32)
+
+ conditions = rel_error < (1 - thresholds)
+ mra = conditions.float().mean()
+ return mra.item()
+
+
+ def reward_fn(sample, model_output, question_type):
+ try:
+ output_ans = extract_answer(model_output)
+ if output_ans == '':
+ output_ans = model_output
+ gt_ans = extract_answer(sample.get("solution", ""))
+ if question_type == "multiple choice":
+ return 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
+ elif question_type == "numerical":
+ gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
+ out_has_decimal = ("." in output_ans) or ("," in output_ans)
+ if gt_has_decimal != out_has_decimal:
+ return 0.0
+ gt_number = normalize_number(gt_ans)
+ out_number = normalize_number(output_ans)
+ if gt_number is None or out_number is None:
+ return 0.0
+ return 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
+ elif question_type == "regression":
+ gt_number = normalize_number(gt_ans)
+ out_number = normalize_number(output_ans)
+ if gt_number is None or out_number is None:
+ return 0.0
+ mra = mean_relative_accuracy(out_number, gt_number)
+ return mra
+ else:
+ return 0.0
+ except Exception as e:
+ return 0.0
+
+ mean_acc = []
+ mean_mra = []
+ for i in tqdm(range(start_idx, len(messages), BSZ), desc="Processing batches"):
+ batch_messages = messages[i:i + BSZ]
+
+ prompts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+
+
+ try:
+ image_inputs, video_inputs, video_kwargs = process_vision_info(batch_messages, return_video_kwargs=True)
+
+ image_idx = 0
+ video_idx = 0
+
+ llm_inputs = []
+
+
+ for idx, prompt in enumerate(prompts):
+ mm_type = batch_messages[idx][0]['content'][0]['type']
+ sample_mm_data = {}
+ sample_video_kw = {}
+ if mm_type == 'image':
+ sample_mm_data["image"] = image_inputs[image_idx]
+ image_idx += 1
+ elif mm_type == 'video':
+ sample_mm_data["video"] = video_inputs[video_idx]
+ for key, value in video_kwargs.items():
+ sample_video_kw[key] = value[video_idx]
+ video_idx += 1
+
+
+ llm_inputs.append({
+ "prompt": prompt,
+ "multi_modal_data": sample_mm_data,
+ "mm_processor_kwargs": sample_video_kw,
+ })
+
+
+ outputs = llm.generate(llm_inputs, sampling_params=sampling_params)
+ batch_output_text = [out.outputs[0].text for out in outputs]
+
+ except Exception as e:
+ print('error:', data[i]['path'])
+ print('Exception:', e)
+ batch_output_text = ['<answer>error</answer>'] * BSZ
+
+
+ for j, (sample, model_output) in enumerate(zip(data[i:i+BSZ], batch_output_text), start=i):
+ think_chain = extract_think(model_output)
+ final_ans = extract_answer(model_output)
+ if final_ans == "":
+ final_ans = model_output
+ sample["output"] = model_output
+ sample["prediction"] = final_ans
+ q_type = sample.get("problem_type", "")
+ sample["reward"] = reward_fn(sample, model_output, q_type)
+ sample['correct'] = sample["reward"] == 1.0
+ if sample['problem_type'] != 'regression':
+ mean_acc.append(sample["reward"])
+ else:
+ mean_mra.append(sample["reward"])
+ if think_chain:
+ sample["process"] = f"<think>{think_chain}</think>"
+ final_output.append(sample)
+
+
+ try:
+ with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
+ json.dump({"results": final_output}, f, indent=2, ensure_ascii=False)
+ print(f"Processed batch {(i - start_idx)//BSZ + 1}, saved {len(final_output)} samples.")
+ except Exception as e:
+ print(f"Error writing to output file: {e}")
+
+ final_acc = {'mean_acc': 0.0, 'mean_mra': 0.0}
+ final_acc['mean_acc'] = torch.tensor(mean_acc).mean().item()
+ if mean_mra != []:
+ final_acc['mean_mra'] = torch.tensor(mean_mra).mean().item()
+
+ try:
+ with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
+ json.dump({"results": final_output, "final_acc": [final_acc]}, f, indent=2, ensure_ascii=False)
+ print(f"Final accuracy saved to {OUTPUT_PATH}")
+ except Exception as e:
+ print(f"Error writing final accuracy to output file: {e}")
+
+ print(f"Results saved to {OUTPUT_PATH}")
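The regression metric above, `mean_relative_accuracy`, scores a prediction by the fraction of accuracy thresholds 0.50, 0.55, ..., 0.95 whose relative-error bound it satisfies. A pure-Python replica for intuition (the script uses a float32 torch version, which may differ at exact threshold boundaries):

```python
def mra(pred, target, start=0.5, end=0.95, interval=0.05):
    # Fraction of thresholds t in {0.50, 0.55, ..., 0.95} for which the
    # relative error |pred - target| / |target| is below 1 - t.
    eps = 1e-8
    rel_error = abs(pred - target) / (abs(target) + eps)
    n = round((end - start) / interval) + 1  # 10 thresholds
    thresholds = [start + i * interval for i in range(n)]
    hits = [rel_error < (1 - t) for t in thresholds]
    return sum(hits) / len(hits)
```

For example, predicting 7 against a ground truth of 10 gives a relative error of 0.3, which clears the looser half of the thresholds only, so the score is 0.5; an exact prediction scores 1.0 and a wildly off one 0.0.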
src/eval_bench.sh ADDED
@@ -0,0 +1,36 @@
+ #!/bin/bash
+ # run_models.sh
+
+ # export HF_HOME=/apdcephfs_sh2/share_300000800/user/zongxia/hf_cache
+ # export TRANSFORMERS_CACHE=/apdcephfs_sh2/share_300000800/user/zongxia/hf_cache
+
+ ./move_eval.sh
+
+ model_paths=(
+ # "Qwen/Qwen2.5-VL-3B-Instruct"
+ # "/apdcephfs_sh2/share_300000800/user/zongxia/Video-R1/src/r1-v/log/3B-Video-GRPO-NoDesEval/checkpoint-1000"
+ # "/apdcephfs_sh2/share_300000800/user/zongxia/Video-R1/src/r1-v/log/3B-Video-GRPO-selfEval-ThenNoDesEval/pool_numerical_chunk_02/checkpoint-42"
+ # "/apdcephfs_sh2/share_300000800/user/zongxia/Video-R1/src/r1-v/log/3B-Video-GRPO-AnswerBERT/video_pool_multiple_choice_chunk_02/checkpoint-46"
+ # "Video-R1/Video-R1-7B"
+ "zli12321/VideoHallu-R1-v3"
+ # "Qwen/Qwen2.5-VL-7B-Instruct"
+ )
+
+ file_names=(
+ # "qwen_3B_base"
+ # "qwen_3B_noDesEval"
+ # "qwen_3B_answerBERT_thenNoDesEval"
+ # "qwen_3B_answerBERT_video12"
+ # "video-R1-7B"
+ "VideoHallu-R1-v3"
+ # "Qwen2.5-VL-7B-Instruct"
+ )
+
+ export DECORD_EOF_RETRY_MAX=20480
+
+
+ for i in "${!model_paths[@]}"; do
+ model="${model_paths[$i]}"
+ file_name="${file_names[$i]}"
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python ./src/eval_bench.py --model_path "$model" --file_name "$file_name"
+ done
src/eval_bench_4567.sh ADDED
@@ -0,0 +1,32 @@
+ #!/bin/bash
+ # run_models.sh
+
+ # export HF_HOME=/apdcephfs_sh2/share_300000800/user/zongxia/hf_cache
+ # export TRANSFORMERS_CACHE=/apdcephfs_sh2/share_300000800/user/zongxia/hf_cache
+
+ ./move_eval.sh
+
+ model_paths=(
+ # "Qwen/Qwen2.5-VL-7B-Instruct"
+ # "/apdcephfs_sh2/share_300000800/user/zongxia/Video-R1/src/r1-v/log/3B-Video-GRPO-AnswerBERT/video_pool_multiple_choice_chunk_01/checkpoint-46"
+ # "/apdcephfs_sh2/share_300000800/user/zongxia/Video-R1/src/r1-v/log/3B-Video-GRPO-SelfEval-Train/pool_numerical_chunk_01/checkpoint-25"
+ # "/apdcephfs_sh2/share_300000800/user/zongxia/Video-R1/src/r1-v/log/3B-Video-GRPO-NoDesEval/pool_multiple_choice_chunk_01/checkpoint-57"
+ # "Video-R1/Qwen2.5-VL-7B-COT-SFT"
+ "zli12321/VideoHallu-R1-v1.0"
+ )
+
+ file_names=(
+ # "qwen_3B_selfEval_mcq1_nume1"
+ # "qwen_3B_NoDesEval_mcq1"
+ # "Video-R1-7B-COT-SFT"
+ "VideoHallu-R1-v1.0"
+ )
+
+ export DECORD_EOF_RETRY_MAX=20480
+
+
+ for i in "${!model_paths[@]}"; do
+ model="${model_paths[$i]}"
+ file_name="${file_names[$i]}"
+ CUDA_VISIBLE_DEVICES=4,5,6,7 python ./src/eval_bench.py --model_path "$model" --file_name "$file_name"
+ done
src/generate_cot_vllm.py ADDED
@@ -0,0 +1,266 @@
+ import os
+ import json
+ import re
+ from tqdm import tqdm
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+ from rouge_score import rouge_scorer
+ import torch
+
+ from transformers import AutoProcessor, AutoTokenizer
+ from vllm import LLM, SamplingParams
+ from qwen_vl_utils import process_vision_info
+
+
+ MODEL_PATH = "Qwen/Qwen2.5-VL-72B-Instruct"
+ BSZ = 32
+
+
+ llm = LLM(
+ model=MODEL_PATH,
+ tensor_parallel_size=torch.cuda.device_count(),
+ max_model_len = 8192,
+ gpu_memory_utilization=0.8,
+ limit_mm_per_prompt={"image": 10, "video": 10},
+ )
+
+ sampling_params = SamplingParams(
+ temperature=1.0,
+ top_p=0.95,
+ max_tokens=512,
+ stop_token_ids=[],
+ )
+
+
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ tokenizer.padding_side = "left"
+ processor.tokenizer = tokenizer
+
+ for dataset_name in ['your_data_name']:
+
+ OUTPUT_PATH = f"./src/r1-v/Video-R1-data/{dataset_name}_COT_qwen72b.json"
+ PROMPT_PATH = f"./src/r1-v/Video-R1-data/{dataset_name}.json"
+
+ data = []
+ if PROMPT_PATH.endswith('.jsonl'):
+ with open(PROMPT_PATH, "r", encoding="utf-8") as f:
+ for line in f:
+ data.append(json.loads(line))
+ elif PROMPT_PATH.endswith('.json'):
+ with open(PROMPT_PATH, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ else:
+ raise ValueError("Input file must be .json or .jsonl")
+
+
+ QUESTION_TEMPLATE = (
+ "{Question}\n"
+ "Please think about this question as if you were a human pondering deeply. "
+ "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
+ "It's encouraged to include self-reflection or verification in the reasoning process. "
+ "Provide your detailed reasoning between the <think> and </think> tags, and then give your final answer between the <answer> and </answer> tags."
+ )
+
+ TYPE_TEMPLATE = {
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
+ }
+
+
+ messages = []
+ for x in data:
+ if x["problem_type"] == 'multiple choice':
+ question = x['problem'] + "Options:\n"
+ for op in x["options"]:
+ question += op + "\n"
+ else:
+ question = x['problem']
+
+ msg = [{
+ "role": "user",
+ "content": [
+ {
+ "type": x['data_type'],
+ x['data_type']: os.getcwd() + "/src/r1-v/Video-R1-data" + x['path'][1:]
+ },
+ {
+ "type": "text",
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[x['problem_type']]
+ }
+ ]
+ }]
+ messages.append(msg)
+
+ # For resume
+ final_output = []
+ start_idx = 0
+ if os.path.exists(OUTPUT_PATH):
+ try:
+ with open(OUTPUT_PATH, "r", encoding="utf-8") as f:
+ existing = json.load(f)
+ final_output = existing.get("results", [])
+ start_idx = len(final_output)
+ print(f"Resuming from sample index {start_idx}")
+ except Exception as e:
+ print(f"Error reading existing output file: {e}")
+
+ def extract_think(output_str):
+ pattern = r'<think>\s*(.*?)\s*</think>'
+ match = re.search(pattern, output_str, re.DOTALL)
+ if match:
+ return match.group(1).strip()
+ return ""
+
+ def extract_answer(text):
+ pattern = r'<answer>\s*(.*?)\s*</answer>'
+ match = re.search(pattern, text, re.DOTALL)
+ if match:
+ return match.group(1).strip()
+ return ""
+
+ def normalize_number(num_str):
+ try:
+ num_str = num_str.replace(',', '')
+ return float(num_str)
+ except Exception as e:
+ print(f"Error converting '{num_str}' to float: {e}")
+ return None
+
+ def wer(reference, hypothesis):
+ ref_words = reference.split()
+ hyp_words = hypothesis.split()
+ m = len(ref_words)
+ n = len(hyp_words)
+ d = [[0]*(n+1) for _ in range(m+1)]
+ for i in range(m+1):
+ d[i][0] = i
+ for j in range(n+1):
+ d[0][j] = j
+ for i in range(1, m+1):
+ for j in range(1, n+1):
+ if ref_words[i-1] == hyp_words[j-1]:
+ d[i][j] = d[i-1][j-1]
+ else:
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
+ return d[m][n] / max(1, m)
+
+ def compute_bleu_score(reference, hypothesis):
+ try:
+ smoothing = SmoothingFunction().method1
+ ref_tokens = reference.split()
+ hyp_tokens = hypothesis.split()
+ score = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=smoothing)
+ return score
+ except Exception as e:
+ print(f"Error computing BLEU score: {e}")
+ return 0.0
+
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
+ scores = scorer.score(reference, hypothesis)
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
+ return average_fmeasure
+
+ def reward_fn(sample, model_output, question_type):
+ try:
+ output_ans = extract_answer(model_output)
+ gt_ans = extract_answer(sample.get("solution", ""))
+ if question_type == "multiple choice":
+ return 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
+ elif question_type == "numerical":
+ gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
+ out_has_decimal = ("." in output_ans) or ("," in output_ans)
+ if gt_has_decimal != out_has_decimal:
+ return 0.0
+ gt_number = normalize_number(gt_ans)
+ out_number = normalize_number(output_ans)
+ if gt_number is None or out_number is None:
+ return 0.0
+ return 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
+ elif question_type == "OCR":
+ error_rate = wer(gt_ans, output_ans)
+ reward = 1 - error_rate
+ return max(0.0, min(1.0, reward))
+ elif question_type == "free-form":
+ score = compute_rouge_score(gt_ans, output_ans)
+ return max(0.0, min(1.0, score))
+ elif question_type == "regression":
+ gt_number = normalize_number(gt_ans)
+ out_number = normalize_number(output_ans)
+ if gt_number is None or out_number is None:
+ return 0.0
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
+ rel_diff = min(1.0, max(0.0, rel_diff))
+ return 1 - rel_diff
+ else:
+ return 0.0
+ except Exception as e:
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
+ return 0.0
+
+
+ for i in tqdm(range(start_idx, len(messages), BSZ), desc="Processing batches"):
+ batch_messages = messages[i:i + BSZ]
+
+ prompts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+
+ try:
+ image_inputs, video_inputs, video_kwargs = process_vision_info(batch_messages, return_video_kwargs=True)
+
+ image_idx = 0
+ video_idx = 0
+
+ llm_inputs = []
+
+
+ for idx, prompt in enumerate(prompts):
+ mm_type = batch_messages[idx][0]['content'][0]['type']
+ sample_mm_data = {}
+ sample_video_kw = {}
+ if mm_type == 'image':
+ sample_mm_data["image"] = image_inputs[image_idx]
+ image_idx += 1
+ elif mm_type == 'video':
+ sample_mm_data["video"] = video_inputs[video_idx]
+ for key, value in video_kwargs.items():
+ sample_video_kw[key] = value[video_idx]
+ video_idx += 1
+
+
+ llm_inputs.append({
+ "prompt": prompt,
+ "multi_modal_data": sample_mm_data,
+ "mm_processor_kwargs": sample_video_kw,
+ })
+
+
+ outputs = llm.generate(llm_inputs, sampling_params=sampling_params)
+ batch_output_text = [out.outputs[0].text for out in outputs]
+
+ except Exception as e:
+ print('error:', data[i]['path'])
+ batch_output_text = ['<answer>error</answer>'] * BSZ
+
+
+ for j, (sample, model_output) in enumerate(zip(data[i:i+BSZ], batch_output_text), start=i):
+ think_chain = extract_think(model_output)
+ final_ans = extract_answer(model_output)
+ sample["answer"] = final_ans
+ q_type = sample.get("problem_type", "")
+ sample["reward"] = reward_fn(sample, model_output, q_type)
+ sample['select'] = sample["reward"] > 0.6
+ if think_chain:
+ sample["process"] = f"<think>{think_chain}</think>"
+ final_output.append(sample)
+
+ try:
+ with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
+ json.dump({"results": final_output}, f, indent=2, ensure_ascii=False)
+ print(f"Processed batch {(i - start_idx)//BSZ + 1}, saved {len(final_output)} samples.")
+ except Exception as e:
+ print(f"Error writing to output file: {e}")
+
+ print(f"Results saved to {OUTPUT_PATH}")
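The OCR reward in the script above is 1 − WER clipped to [0, 1], where `wer` is a word-level Levenshtein distance normalized by the reference length. A self-contained copy of that computation, paired with the clipping step, for quick sanity-checking:

```python
def wer(reference, hypothesis):
    # Word-level edit distance (insertions, deletions, substitutions)
    # divided by the reference word count, as in generate_cot_vllm.py.
    ref_words, hyp_words = reference.split(), hypothesis.split()
    m, n = len(ref_words), len(hyp_words)
    d = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        d[i][0] = i
    for j in range(n + 1):
        d[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if ref_words[i - 1] == hyp_words[j - 1]:
                d[i][j] = d[i - 1][j - 1]
            else:
                d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
    return d[m][n] / max(1, m)

def ocr_reward(gt_ans, output_ans):
    # Reward is 1 - WER, clipped into [0, 1] (WER can exceed 1 when the
    # hypothesis is much longer than the reference).
    return max(0.0, min(1.0, 1 - wer(gt_ans, output_ans)))
```

For instance, dropping one word from a four-word reference gives a WER of 0.25 and therefore a reward of 0.75.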
src/inference_example.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import torch
+ from vllm import LLM, SamplingParams
+ from transformers import AutoProcessor, AutoTokenizer
+ from qwen_vl_utils import process_vision_info
+
+ # Set model path
+ model_path = "Video-R1/Video-R1-7B"
+
+ # Set video path and question
+ video_path = "./src/example_video/video1.mp4"
+ question = "Which move motion in the video lose the system energy?"
+
+ # Choose the question type from 'multiple choice', 'numerical', 'OCR', 'free-form', 'regression'
+ problem_type = 'free-form'
+
+ # Initialize the LLM
+ llm = LLM(
+ model=model_path,
+ tensor_parallel_size=1,
+ max_model_len=81920,
+ gpu_memory_utilization=0.8,
+ limit_mm_per_prompt={"video": 1, "image": 1},
+ )
+
+ sampling_params = SamplingParams(
+ temperature=0.1,
+ top_p=0.001,
+ max_tokens=1024,
+ )
+
+ # Load processor and tokenizer
+ processor = AutoProcessor.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ tokenizer.padding_side = "left"
+ processor.tokenizer = tokenizer
+
+ # Prompt template
+ QUESTION_TEMPLATE = (
+ "{Question}\n"
+ "Please think about this question as if you were a human pondering deeply. "
+ "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
+ "It's encouraged to include self-reflection or verification in the reasoning process. "
+ "Provide your detailed reasoning between the <think> and </think> tags, and then give your final answer between the <answer> and </answer> tags."
+ )
+
+ # Question type
+ TYPE_TEMPLATE = {
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
+ }
+
+ # Construct multimodal message
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "video",
+ "video": video_path,
+ "max_pixels": 200704, # max pixels for each frame
+ "nframes": 32 # max frame number
+ },
+ {
+ "type": "text",
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[problem_type]
+ },
+ ],
+ }
+ ]
+
+ # Convert to prompt string
+ prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ # Process video input
+ image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
+
+ # Prepare vLLM input
+ llm_inputs = [{
+ "prompt": prompt,
+ "multi_modal_data": {"video": video_inputs[0]},
+ "mm_processor_kwargs": {key: val[0] for key, val in video_kwargs.items()},
+ }]
+
+ # Run inference
+ outputs = llm.generate(llm_inputs, sampling_params=sampling_params)
+ output_text = outputs[0].outputs[0].text
+
+ print(output_text)
+
src/scripts/run_grpo_video.sh ADDED
@@ -0,0 +1,44 @@
+ cd src/r1-v
+
+ export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
+ export LOG_PATH="./debug_log_2b.txt"
+
+ # For resume training: --resume_from_checkpoint Model_Path \
+ # Set temporal to choose between T-GRPO and GRPO, and len_control to enable or disable the length control reward.
+
+ # Qwen/Qwen2.5-VL-3B-Instruct
+
+ CUDA_VISIBLE_DEVICES=0,1,2,3,5,6,7 torchrun --nproc_per_node="7" \
+ --nnodes="1" \
+ --node_rank="0" \
+ --master_addr="127.0.0.1" \
+ --master_port="12365" \
+ src/open_r1/grpo.py \
+ --output_dir "./log/Qwen2.5-VL-3B-GRPO" \
+ --model_name_or_path 'Qwen/Qwen2.5-VL-3B-Instruct' \
+ --dataset_name "./Video-R1-data/Video-R1-260k.json" \
+ --deepspeed local_scripts/zero3.json \
+ --max_prompt_length 16384 \
+ --max_completion_length 768 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --learning_rate 1e-6 \
+ --lr_scheduler_type "cosine" \
+ --weight_decay 0.01 \
+ --bf16 \
+ --logging_steps 1 \
+ --gradient_checkpointing true \
+ --temporal true \
+ --len_control true \
+ --attn_implementation flash_attention_2 \
+ --max_pixels 401408 \
+ --num_train_epochs 1 \
+ --run_name Video-R1 \
+ --save_steps 100 \
+ --beta 0.04 \
+ --max_grad_norm 5 \
+ --save_only_model false \
+ --num_generations 8 # number of outputs G in GRPO; reducing it speeds up training and lowers memory cost, but increases variance
+
+
+ python /apdcephfs_sh2/share_300000800/user/zongxia/Video-R1/gpu_burn.py