diff --git a/.gitattributes b/.gitattributes
index 45dc233c563f29828647ea33a44da9e56ac538fe..3c0b6ea515ec434cadf4b0e06c56ada063f1e0ab 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -45,3 +45,4 @@ previous_version/Video-R1-main-previous/images/sample.png filter=lfs diff=lfs me
previous_version/Video-R1-main-previous/images/CATER_new_003595.gif filter=lfs diff=lfs merge=lfs -text
previous_version/Video-R1-main-previous/images/2B_curve.png filter=lfs diff=lfs merge=lfs -text
previous_version/Video-R1-main-previous/images/7B_curve.png filter=lfs diff=lfs merge=lfs -text
+src/example_video/video1.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/previous_version/Video-R1-main-previous/README.md b/previous_version/Video-R1-main-previous/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5ba8e171cf6ba52309bfff84ad39943c900d06dd
--- /dev/null
+++ b/previous_version/Video-R1-main-previous/README.md
@@ -0,0 +1,143 @@
+# Video-R1: Towards Super Reasoning Ability in Video Understanding
+
+This work aims to integrate deep thinking capabilities into video understanding tasks through the R1 paradigm.
+
+For the first time, we achieved a simultaneous increase in both accuracy and thinking length in video understanding domain.
+
+This is a preliminary repo, and we will continue to develop our Video-R1 model in the future.
+
+## Updates
+- [2025/02/23] We release training code and data of Video-R1
+
+
+
+## Findings
+
+### *Shared Growth of Accuracy and Thinking Length is Possible in Video*
+
+In many previous multimodal R1 repositories, the thinking length either showed little to no increase (e.g., [Open R1 Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video?tab=readme-ov-file) ) or even decreased (e.g., [R1-V](https://github.com/Deep-Agent/R1-V) ).
+
+In this work, we demonstrate that this issue can be addressed by using an appropriate base model and a strong reasoning dataset. We train [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) using GRPO with accuracy and format rewards on the [DVD-counting](https://huggingface.co/datasets/Video-R1/DVD-counting) dataset. Training the 7B model for 900 steps can be completed in approximately 10 hours using 4 x A100 (80G) GPUs. The training curve is as follows:
+
+
+
+
+
+### *Weak Base Model Hinders the Emergence of Deep Thinking in Video*
+
+We train [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) using the same setting on the DVD-counting dataset. In contrast, this model shows a decrease in thinking length.
+
+In some cases, the model even skips the thinking process and outputs sentences like this: `\n\n2`.
+
+
+
+
+
+
+
+
+### *Weak Reasoning Data Maybe Not Beneficial for Reinforcing Deep Thinking*
+
+We train [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) on a subset of NExT-QA dataset with little reasoning. We can notice that there is almost no increase in thinking length. This indicates that reinforcing deep thinking may require strong reasoning data.
+
+
+
+
+
+
+
+
+
+## Datasets
+
+The video files are in the zip file and the train/test splits are in the jsonl file.
+
+[🤗 Video-R1 Dataset: DVD-counting](https://huggingface.co/datasets/Video-R1/DVD-counting)
+
+This dataset is extracted from "DVD: A Diagnostic Dataset for Multi-step Reasoning in Video Grounded Dialogue"
+
+## Performance
+
+We can observe that RL training results in an accuracy boost of around 10% on DVD-counting-test
+
+
+
+
+| Dataset | Qwen2-VL-7B-Instruct | Video-R1-7B |
+| ----------------- | -------------------- | ----------- |
+| DVD-counting-test | 25.0 | 34.5 |
+
+
+
+Reasoning Samples:
+
+
+
+

+
+
+
+
+

+
+
+
+
+
+
+
+## Set up
+
+```bash
+git clone https://github.com/tulerfeng/Video-R1
+cd Video-R1
+
+# build environment
+conda create -n video-r1 python=3.11
+conda activate video-r1
+bash setup.sh
+
+# qwen video extraction setting
+cd src/qwen-vl-utils
+pip install -e .
+cd ..
+
+# download dataset
+git lfs install
+git clone https://huggingface.co/datasets/Video-R1/DVD-counting
+```
+
+Please put the downloaded dataset to `src/r1-v/data/`
+
+## Training
+
+Train Qwen2-VL-7B-Instruct with GRPO
+
+```bash
+bash src/scripts/run_grpo_video.sh
+```
+
+
+
+## Evaluation
+
+Evaluation on video counting task
+
+```bash
+python ./src/eval/test_qwen2vl_video_counting.py
+```
+
+
+
+## Acknowledgements
+
+We sincerely appreciate the contributions of the open-source community. The related projects are as follows:
+
++ [R1-V](https://github.com/Deep-Agent/R1-V) (our initial codebase)
++ [Open R1 Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video?tab=readme-ov-file) (concurrent work)
+
+- [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal)
+- [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1)
+- [open-r1](https://github.com/huggingface/open-r1)
+
+
diff --git a/previous_version/Video-R1-main-previous/setup.sh b/previous_version/Video-R1-main-previous/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..adcde9530e10f89165665927e7bb036ffe6fd479
--- /dev/null
+++ b/previous_version/Video-R1-main-previous/setup.sh
@@ -0,0 +1,15 @@
+# Install the packages in r1-v .
+cd src/r1-v
+pip install -e ".[dev]"
+
+# Addtional modules
+pip install wandb==0.18.3
+pip install tensorboardx
+pip install qwen_vl_utils torchvision
+pip install flash-attn --no-build-isolation
+
+# vLLM support
+pip install vllm==0.7.2
+
+# fix transformers version
+pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef
\ No newline at end of file
diff --git a/previous_version/Video-R1-main-previous/src/scripts/run_grpo_video.sh b/previous_version/Video-R1-main-previous/src/scripts/run_grpo_video.sh
new file mode 100644
index 0000000000000000000000000000000000000000..63904e61f126ff8d5a0dc8f9b26c1e8b3d62dc57
--- /dev/null
+++ b/previous_version/Video-R1-main-previous/src/scripts/run_grpo_video.sh
@@ -0,0 +1,34 @@
+cd src/r1-v
+
+export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL
+export LOG_PATH="./debug_log_2b.txt"
+
+
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node="4" \
+ --nnodes="1" \
+ --node_rank="0" \
+ --master_addr="127.0.0.1" \
+ --master_port="12351" \
+ src/open_r1/grpo.py \
+ --output_dir "YOUR_PATH/log_dvd" \
+ --model_name_or_path "Qwen/Qwen2-VL-7B-Instruct" \
+ --dataset_name "YOUR_PATH/data/train_dvd.jsonl" \
+ --deepspeed local_scripts/zero3.json \
+ --max_prompt_length 4096 \
+ --max_completion_length 512 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --learning_rate 1e-6 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation flash_attention_2 \
+ --max_pixels 401408 \
+ --num_train_epochs 2 \
+ --run_name Qwen2-VL-7B-Video-dvd \
+ --save_steps 100 \
+ --max_grad_norm 20 \
+ --save_only_model true \
+ --num_generations 8 # number of outputs G in grpo, reduce it would lead to faster training and smaller memory cost but higher variance
diff --git a/previous_version/Video-R1-main-previous/src/scripts/run_grpo_vllm.sh b/previous_version/Video-R1-main-previous/src/scripts/run_grpo_vllm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..65438b2cdc41dc05bf98987b2d085967f420aee3
--- /dev/null
+++ b/previous_version/Video-R1-main-previous/src/scripts/run_grpo_vllm.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+# The latest vllm==0.7.2 is required for this script: pip3 install vllm==0.7.2
+
+
+export DEBUG_MODE="true"
+export LOG_PATH="./vllm_run.txt"
+
+QWEN_PATH="PATH_TO_QWEN_2B_CKPT"
+HF_DATASET="MMInstruction/Clevr_CoGenT_TrainA_70K_Complex"
+OUTPUT_DIR="OUTPUT_DIR"
+RUN_NAME="RUN_NAME_FOR_WANDB"
+
+# NOTE: you are expected to use X + 1 cards for X training proc and 1 vLLM proc
+# e.g., the visible devices should be 0,1,2,3,4 for 5 cards, and --nproc_per_node="4"
+
+CUDA_VISIBLE_DEVICES="0,1,2,3,4" torchrun --nproc_per_node="4" \
+ --nnodes="1" \
+ --node_rank="0" \
+ --master_addr="127.0.0.1" \
+ --master_port="12345" \
+ src/open_r1/grpo.py --use_vllm True \
+ --output_dir $OUTPUT_DIR \
+ --model_name_or_path $QWEN_PATH \
+ --dataset_name $HF_DATASET \
+ --max_prompt_length 512 \
+ --max_completion_length 1024 \
+ --temperature 1.0 \
+ --num_generations 4 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 4 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation flash_attention_2 \
+ --max_pixels 400000 \
+ --max_steps 13125 \
+ --run_name $RUN_NAME \
+ --save_steps 1000 \
+ --save_only_model true
diff --git a/previous_version/Video-R1-main-previous/src/scripts/run_sft_clevr.sh b/previous_version/Video-R1-main-previous/src/scripts/run_sft_clevr.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4b33dbf45bba6bf935bd12a7f0df15f274932526
--- /dev/null
+++ b/previous_version/Video-R1-main-previous/src/scripts/run_sft_clevr.sh
@@ -0,0 +1 @@
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file src/r1-v/configs/zero2.yaml src/r1-v/src/open_r1/sft.py --config src/r1-v/configs/qwen2vl_sft_config.yaml
diff --git a/previous_version/Video-R1-main-previous/src/scripts/test_grpo_geoqa_multigpu.sh b/previous_version/Video-R1-main-previous/src/scripts/test_grpo_geoqa_multigpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ad92372b0844355fed4e1bfc922889a8a43d6f19
--- /dev/null
+++ b/previous_version/Video-R1-main-previous/src/scripts/test_grpo_geoqa_multigpu.sh
@@ -0,0 +1,15 @@
+r1_v_path=/workspace/xxx/github/R1-V
+cd ${r1_v_path}
+
+model_path=${r1_v_path}/output/train@geo170k/checkpoint-30
+batch_size=4
+output_path=${r1_v_path}/output/train@geo170k/eval/res@checkpoint-30.json
+prompt_path=${r1_v_path}/src/eval/prompts/geoqa_test_prompts.jsonl
+gpu_ids=0,1,2,3,4,5,6,7
+
+python src/eval/test_qwen2vl_geoqa_multigpu.py \
+ --model_path ${model_path} \
+ --batch_size ${batch_size} \
+ --output_path ${output_path} \
+ --prompt_path ${prompt_path} \
+ --gpu_ids ${gpu_ids}
diff --git a/src/example_video/video1.mp4 b/src/example_video/video1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..09907520d15b358527c7cfb5dc9c1df05565d9ec
--- /dev/null
+++ b/src/example_video/video1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fbd07ed2f5a7289459baf24ebf1228ac0303977950eb623031f7d9bc4e51987
+size 1094692
diff --git a/src/qwen-vl-utils/.python-version b/src/qwen-vl-utils/.python-version
new file mode 100644
index 0000000000000000000000000000000000000000..143c2f5d0b57eae26fc9dec0697e64d7e051ab6c
--- /dev/null
+++ b/src/qwen-vl-utils/.python-version
@@ -0,0 +1 @@
+3.8.19
diff --git a/src/qwen-vl-utils/README.md b/src/qwen-vl-utils/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e4c88d7d71be1d33fbc559165b95e229547301c
--- /dev/null
+++ b/src/qwen-vl-utils/README.md
@@ -0,0 +1,94 @@
+# qwen-vl-utils
+
+Qwen-VL Utils contains a set of helper functions for processing and integrating visual language information with Qwen-VL Series Model.
+
+## Install
+
+```bash
+pip install qwen-vl-utils
+```
+
+## Usage
+
+### Qwen2VL
+
+```python
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+
+# You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
+messages = [
+ # Image
+ ## Local file path
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+ ## Image URL
+ [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+ ## Base64 encoded image
+ [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
+ ## PIL.Image.Image
+ [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
+ ## Model dynamically adjusts image size, specify dimensions if required.
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
+ # Video
+ ## Local video path
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
+ ## Local video frames
+ [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
+ ## Model dynamically adjusts video nframes, video height and width. specify args if required.
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
+]
+
+processor = AutoProcessor.from_pretrained(model_path)
+model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+images, videos = process_vision_info(messages)
+inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
+print(inputs)
+generated_ids = model.generate(**inputs)
+print(generated_ids)
+```
+
+### Qwen2.5VL
+
+```python
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+
+# You can set the maximum tokens for a video through the environment variable VIDEO_MAX_PIXELS
+# based on the maximum tokens that the model can accept.
+# export VIDEO_MAX_PIXELS = 32000 * 28 * 28 * 0.9
+
+
+# You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
+messages = [
+ # Image
+ ## Local file path
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+ ## Image URL
+ [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+ ## Base64 encoded image
+ [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
+ ## PIL.Image.Image
+ [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
+ ## Model dynamically adjusts image size, specify dimensions if required.
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
+ # Video
+ ## Local video path
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
+ ## Local video frames
+ [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
+ ## Model dynamically adjusts video nframes, video height and width. specify args if required.
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
+]
+
+processor = AutoProcessor.from_pretrained(model_path)
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
+inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
+print(inputs)
+generated_ids = model.generate(**inputs)
+print(generated_ids)
+```
\ No newline at end of file
diff --git a/src/qwen-vl-utils/pyproject.toml b/src/qwen-vl-utils/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..64bd8a19954fe5376d4b92aa139215e0e392908c
--- /dev/null
+++ b/src/qwen-vl-utils/pyproject.toml
@@ -0,0 +1,75 @@
+[project]
+name = "qwen-vl-utils"
+version = "0.0.10"
+description = "Qwen Vision Language Model Utils - PyTorch"
+authors = [
+ { name = "Qwen Team", email = "chenkeqin.ckq@alibaba-inc.com" },
+]
+dependencies = [
+ "requests",
+ "pillow",
+ "av",
+ "packaging",
+]
+readme = "README.md"
+requires-python = ">= 3.8"
+license = {text = "Apache-2.0"}
+keywords = [
+ 'large language model',
+ 'vision language model',
+ 'qwen-vl',
+ 'pytorch',
+]
+classifiers = [
+ 'Development Status :: 4 - Beta',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Programming Language :: Python :: 3',
+ 'License :: OSI Approved :: Apache Software License',
+]
+
+[project.urls]
+Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
+Repository = "https://github.com/QwenLM/Qwen2-VL.git"
+Issues = "https://github.com/QwenLM/Qwen2-VL/issues"
+
+[project.optional-dependencies]
+decord = [
+ "decord",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.rye]
+managed = true
+dev-dependencies = [
+ "torch",
+ "torchvision",
+]
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/qwen_vl_utils"]
+
+[tool.ruff]
+line-length = 119
+
+[tool.ruff.lint]
+ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
+select = ["C", "E", "F", "I", "W"]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = ["qwen_vl_utils"]
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
diff --git a/src/qwen-vl-utils/requirements-dev.lock b/src/qwen-vl-utils/requirements-dev.lock
new file mode 100644
index 0000000000000000000000000000000000000000..b6441fe5e0e112a59a2ff472528950bae3877698
--- /dev/null
+++ b/src/qwen-vl-utils/requirements-dev.lock
@@ -0,0 +1,84 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+# pre: false
+# features: ["decord"]
+# all-features: false
+# with-sources: false
+# generate-hashes: false
+# universal: false
+
+-e file:.
+av==12.3.0
+ # via qwen-vl-utils
+certifi==2022.12.7
+ # via requests
+charset-normalizer==2.1.1
+ # via requests
+decord==0.6.0
+ # via qwen-vl-utils
+filelock==3.13.1
+ # via torch
+ # via triton
+fsspec==2024.2.0
+ # via torch
+idna==3.4
+ # via requests
+jinja2==3.1.3
+ # via torch
+markupsafe==2.1.5
+ # via jinja2
+mpmath==1.3.0
+ # via sympy
+networkx==3.1
+ # via torch
+numpy==1.24.1
+ # via decord
+ # via torchvision
+nvidia-cublas-cu12==12.1.3.1
+ # via nvidia-cudnn-cu12
+ # via nvidia-cusolver-cu12
+ # via torch
+nvidia-cuda-cupti-cu12==12.1.105
+ # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+ # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+ # via torch
+nvidia-cudnn-cu12==9.1.0.70
+ # via torch
+nvidia-cufft-cu12==11.0.2.54
+ # via torch
+nvidia-curand-cu12==10.3.2.106
+ # via torch
+nvidia-cusolver-cu12==11.4.5.107
+ # via torch
+nvidia-cusparse-cu12==12.1.0.106
+ # via nvidia-cusolver-cu12
+ # via torch
+nvidia-nccl-cu12==2.20.5
+ # via torch
+nvidia-nvjitlink-cu12==12.6.68
+ # via nvidia-cusolver-cu12
+ # via nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+ # via torch
+packaging==24.1
+ # via qwen-vl-utils
+pillow==10.2.0
+ # via qwen-vl-utils
+ # via torchvision
+requests==2.28.1
+ # via qwen-vl-utils
+sympy==1.12
+ # via torch
+torch==2.4.0
+ # via torchvision
+torchvision==0.19.0
+triton==3.0.0
+ # via torch
+typing-extensions==4.9.0
+ # via torch
+urllib3==1.26.13
+ # via requests
diff --git a/src/qwen-vl-utils/requirements.lock b/src/qwen-vl-utils/requirements.lock
new file mode 100644
index 0000000000000000000000000000000000000000..6f9f6037aabc5fcddcef89add96150a76c51dd8a
--- /dev/null
+++ b/src/qwen-vl-utils/requirements.lock
@@ -0,0 +1,32 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+# pre: false
+# features: ["decord"]
+# all-features: false
+# with-sources: false
+# generate-hashes: false
+# universal: false
+
+-e file:.
+av==12.3.0
+ # via qwen-vl-utils
+certifi==2022.12.7
+ # via requests
+charset-normalizer==2.1.1
+ # via requests
+decord==0.6.0
+ # via qwen-vl-utils
+idna==3.4
+ # via requests
+numpy==1.24.4
+ # via decord
+packaging==24.1
+ # via qwen-vl-utils
+pillow==10.2.0
+ # via qwen-vl-utils
+requests==2.28.1
+ # via qwen-vl-utils
+urllib3==1.26.13
+ # via requests
diff --git a/src/qwen-vl-utils/src/qwen_vl_utils/__init__.py b/src/qwen-vl-utils/src/qwen_vl_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa8708442e93d5ec3a02e863ad7ae833952d199
--- /dev/null
+++ b/src/qwen-vl-utils/src/qwen_vl_utils/__init__.py
@@ -0,0 +1,7 @@
+from .vision_process import (
+ extract_vision_info,
+ fetch_image,
+ fetch_video,
+ process_vision_info,
+ smart_resize,
+)
diff --git a/src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/__init__.cpython-311.pyc b/src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b18c7dc282cb8ec33c439fc16a813ddd975540d
Binary files /dev/null and b/src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/vision_process.cpython-311.pyc b/src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/vision_process.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08418f23142463ceeeff3843444b1a9df121483c
Binary files /dev/null and b/src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/vision_process.cpython-311.pyc differ
diff --git a/src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py b/src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff7f007958243257ad60746d57494e06b8f47dd6
--- /dev/null
+++ b/src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
@@ -0,0 +1,379 @@
+from __future__ import annotations
+
+import base64
+import logging
+import math
+import os
+import sys
+import time
+import warnings
+from functools import lru_cache
+from io import BytesIO
+
+import requests
+import torch
+import torchvision
+from packaging import version
+from PIL import Image
+from torchvision import io, transforms
+from torchvision.transforms import InterpolationMode
+from typing import Optional
+
+
+logger = logging.getLogger(__name__)
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 256 * 28 * 28
+MAX_RATIO = 200
+
+# VIDEO_MIN_PIXELS = 128 * 28 * 28
+# VIDEO_MAX_PIXELS = 768 * 28 * 28
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 128 * 28 * 28
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 16
+
+# Set the maximum number of video token inputs.
+# Here, 128K represents the maximum number of input tokens for the VLLM model.
+# Remember to adjust it according to your own configuration.
+VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
+logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
+
+
+def smart_resize(
+ height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+ """
+ Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+ """
+ if max(height, width) / min(height, width) > MAX_RATIO:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = floor_by_factor(height / beta, factor)
+ w_bar = floor_by_factor(width / beta, factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+ return h_bar, w_bar
+
+
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+ if pil_image.mode == 'RGBA':
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
+
+
+def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
+ if "image" in ele:
+ image = ele["image"]
+ else:
+ image = ele["image_url"]
+ image_obj = None
+ if isinstance(image, Image.Image):
+ image_obj = image
+ elif image.startswith("http://") or image.startswith("https://"):
+ response = requests.get(image, stream=True)
+ image_obj = Image.open(BytesIO(response.content))
+ elif image.startswith("file://"):
+ image_obj = Image.open(image[7:])
+ elif image.startswith("data:image"):
+ if "base64," in image:
+ _, base64_data = image.split("base64,", 1)
+ data = base64.b64decode(base64_data)
+ image_obj = Image.open(BytesIO(data))
+ else:
+ image_obj = Image.open(image)
+ if image_obj is None:
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+ image = to_rgb(image_obj)
+ ## resize
+ if "resized_height" in ele and "resized_width" in ele:
+ resized_height, resized_width = smart_resize(
+ ele["resized_height"],
+ ele["resized_width"],
+ factor=size_factor,
+ )
+ else:
+ width, height = image.size
+ min_pixels = ele.get("min_pixels", MIN_PIXELS)
+ max_pixels = ele.get("max_pixels", MAX_PIXELS)
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=size_factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ return image
+
+
+def smart_nframes(
+ ele: dict,
+ total_frames: int,
+ video_fps: int | float,
+) -> int:
+ """calculate the number of frames for video used for model inputs.
+
+ Args:
+ ele (dict): a dict contains the configuration of video.
+ support either `fps` or `nframes`:
+ - nframes: the number of frames to extract for model inputs.
+ - fps: the fps to extract frames for model inputs.
+ - min_frames: the minimum number of frames of the video, only used when fps is provided.
+ - max_frames: the maximum number of frames of the video, only used when fps is provided.
+ total_frames (int): the original total number of frames of the video.
+ video_fps (int | float): the original fps of the video.
+
+ Raises:
+ ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+ Returns:
+ int: the number of frames for video used for model inputs.
+ """
+ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+ if "nframes" in ele:
+ nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+ else:
+ fps = ele.get("fps", FPS)
+ min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+ max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
+ nframes = total_frames / video_fps * fps
+ if nframes > total_frames:
+ logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
+ nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+ nframes = floor_by_factor(nframes, FRAME_FACTOR)
+ if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+ raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
+ return nframes
+
+
+def _read_video_torchvision(
+ ele: dict,
+) -> (torch.Tensor, float):
+ """read video using torchvision.io.read_video
+
+ Args:
+ ele (dict): a dict contains the configuration of video.
+ support keys:
+ - video: the path of video. support "file://", "http://", "https://" and local path.
+ - video_start: the start time of video.
+ - video_end: the end time of video.
+ Returns:
+ torch.Tensor: the video tensor with shape (T, C, H, W).
+ """
+ video_path = ele["video"]
+ if version.parse(torchvision.__version__) < version.parse("0.19.0"):
+ if "http://" in video_path or "https://" in video_path:
+ warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
+ if "file://" in video_path:
+ video_path = video_path[7:]
+ st = time.time()
+ video, audio, info = io.read_video(
+ video_path,
+ start_pts=ele.get("video_start", 0.0),
+ end_pts=ele.get("video_end", None),
+ pts_unit="sec",
+ output_format="TCHW",
+ )
+ total_frames, video_fps = video.size(0), info["video_fps"]
+ logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long()
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+ video = video[idx]
+ return video, sample_fps
+
+
+def is_decord_available() -> bool:
+ import importlib.util
+
+ return importlib.util.find_spec("decord") is not None
+
+
+def _read_video_decord(
+ ele: dict,
+) -> (torch.Tensor, float):
+ """read video using decord.VideoReader
+
+ Args:
+ ele (dict): a dict contains the configuration of video.
+ support keys:
+ - video: the path of video. support "file://", "http://", "https://" and local path.
+ - video_start: the start time of video.
+ - video_end: the end time of video.
+ Returns:
+ torch.Tensor: the video tensor with shape (T, C, H, W).
+ """
+ import decord
+ video_path = ele["video"]
+ st = time.time()
+ vr = decord.VideoReader(video_path)
+ # TODO: support start_pts and end_pts
+ if 'video_start' in ele or 'video_end' in ele:
+ raise NotImplementedError("not support start_pts and end_pts in decord for now.")
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
+ logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
+ video = vr.get_batch(idx).asnumpy()
+ video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+ return video, sample_fps
+
+
+VIDEO_READER_BACKENDS = {
+ "decord": _read_video_decord,
+ "torchvision": _read_video_torchvision,
+}
+
+FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
+
+
+@lru_cache(maxsize=1)
+def get_video_reader_backend() -> str:
+ if FORCE_QWENVL_VIDEO_READER is not None:
+ video_reader_backend = FORCE_QWENVL_VIDEO_READER
+ elif is_decord_available():
+ video_reader_backend = "decord"
+ else:
+ video_reader_backend = "torchvision"
+ print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
+ return video_reader_backend
+
+
+def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
+ if isinstance(ele["video"], str):
+ video_reader_backend = get_video_reader_backend()
+ try:
+ video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
+ except Exception as e:
+ logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
+ video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
+
+ nframes, _, height, width = video.shape
+ min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+ total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+ max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
+ max_pixels_supposed = ele.get("max_pixels", max_pixels)
+ if max_pixels_supposed > max_pixels:
+ logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
+ max_pixels = min(max_pixels_supposed, max_pixels)
+ if "resized_height" in ele and "resized_width" in ele:
+ resized_height, resized_width = smart_resize(
+ ele["resized_height"],
+ ele["resized_width"],
+ factor=image_factor,
+ )
+ else:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=image_factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ video = transforms.functional.resize(
+ video,
+ [resized_height, resized_width],
+ interpolation=InterpolationMode.BICUBIC,
+ antialias=True,
+ ).float()
+ if return_video_sample_fps:
+ return video, sample_fps
+ return video
+ else:
+ assert isinstance(ele["video"], (list, tuple))
+ process_info = ele.copy()
+ process_info.pop("type", None)
+ process_info.pop("video", None)
+ images = [
+ fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
+ for video_element in ele["video"]
+ ]
+ nframes = ceil_by_factor(len(images), FRAME_FACTOR)
+ if len(images) < nframes:
+ images.extend([images[-1]] * (nframes - len(images)))
+ if return_video_sample_fps:
+ return images, process_info.pop("fps", 2.0)
+ return images
+
+
+def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
+ vision_infos = []
+ if isinstance(conversations[0], dict):
+ conversations = [conversations]
+ for conversation in conversations:
+ for message in conversation:
+ if isinstance(message["content"], list):
+ for ele in message["content"]:
+ if (
+ "image" in ele
+ or "image_url" in ele
+ or "video" in ele
+ or ele["type"] in ("image", "image_url", "video")
+ ):
+ vision_infos.append(ele)
+ return vision_infos
+
+
+def process_vision_info(
+ conversations: list[dict] | list[list[dict]],
+ return_video_kwargs: bool = False,
+) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
+
+ vision_infos = extract_vision_info(conversations)
+ ## Read images or videos
+ image_inputs = []
+ video_inputs = []
+ video_sample_fps_list = []
+ for vision_info in vision_infos:
+ if "image" in vision_info or "image_url" in vision_info:
+ image_inputs.append(fetch_image(vision_info))
+ elif "video" in vision_info:
+ video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
+ video_sample_fps_list.append(video_sample_fps)
+ video_inputs.append(video_input)
+ else:
+ raise ValueError("image, image_url or video should in content.")
+ if len(image_inputs) == 0:
+ image_inputs = None
+ if len(video_inputs) == 0:
+ video_inputs = None
+ if return_video_kwargs:
+ return image_inputs, video_inputs, {'fps': video_sample_fps_list}
+ return image_inputs, video_inputs
diff --git a/src/r1-v/Evaluation/check_path_mp4.py b/src/r1-v/Evaluation/check_path_mp4.py
new file mode 100644
index 0000000000000000000000000000000000000000..c79e237cb7bdd786f7aca53a422181a944a10604
--- /dev/null
+++ b/src/r1-v/Evaluation/check_path_mp4.py
@@ -0,0 +1,112 @@
+import json
+import os
+import subprocess
+from tqdm import tqdm
+
+def is_strict_mp4(file_path):
+ """
+ Check the video file's format information using ffprobe.
+ If the 'format_name' contains "mp4", then the file meets the strict mp4 encoding requirements;
+ otherwise, return False along with ffprobe's output information.
+ """
+ command = [
+ "ffprobe",
+ "-v", "error",
+ "-print_format", "json",
+ "-show_format",
+ file_path
+ ]
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+ if result.returncode != 0:
+ return False, result.stderr
+ try:
+ info = json.loads(result.stdout)
+ format_name = info.get("format", {}).get("format_name", "")
+ tokens = [token.strip() for token in format_name.split(',')]
+ if "mp4" in tokens:
+ return True, result.stdout
+ else:
+ return False, result.stdout
+ except Exception as e:
+ return False, str(e)
+
+def convert_to_mp4(input_file, output_file):
+ """
+ Use ffmpeg to convert the video to MP4 encoding.
+ The output is saved as a temporary file, and if the conversion is successful,
+ the temporary file replaces the output_file.
+ A scale filter is added to ensure the output resolution dimensions are even,
+ preventing errors from libx264.
+ """
+ temp_file = output_file + ".temp.mp4"
+ command = [
+ "ffmpeg",
+ "-y", # Overwrite output file if it exists
+ "-i", input_file, # Input file
+ "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", # Ensure width and height are even numbers
+ "-c:v", "libx264", # Use libx264 for video encoding
+ "-c:a", "aac", # Use AAC for audio encoding
+ temp_file
+ ]
+ print(f"Converting: {input_file} -> {output_file}")
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+ if result.returncode != 0:
+ print(f"Conversion failed: {input_file}\n{result.stderr}")
+ if os.path.exists(temp_file):
+ os.remove(temp_file)
+ return False
+ else:
+ os.replace(temp_file, output_file)
+ print(f"Conversion succeeded: {output_file}")
+ return True
+
+def find_alternative(file_path):
+ """
+ If the file specified by file_path does not exist, try to find a file with the same base name
+ but with a different extension in the same directory.
+ """
+ dir_name = os.path.dirname(file_path)
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
+ if not os.path.exists(dir_name):
+ return None
+ for candidate in os.listdir(dir_name):
+ candidate_base, candidate_ext = os.path.splitext(candidate)
+ if candidate_base == base_name and candidate_ext.lower() != ".mp4":
+ candidate_full = os.path.join(dir_name, candidate)
+ if os.path.isfile(candidate_full):
+ return candidate_full
+ return None
+
+def process_videos_from_json(json_file):
+ with open(json_file, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+
+ checked_paths = set() # Record the file paths that have been checked
+ for item in tqdm(data, desc="Processing videos", unit="item"):
+ file_path = item.get("path", "").strip()
+ # Skip if the file has already been checked
+ if file_path in checked_paths:
+ continue
+ checked_paths.add(file_path)
+
+ if os.path.exists(file_path):
+ strict, info = is_strict_mp4(file_path)
+ if not strict:
+ print(f"\nVideo does not meet strict mp4 encoding requirements: {file_path}")
+ print("ffprobe output:")
+ print(info)
+ # Convert the existing file to mp4 encoding (overwrite)
+ convert_to_mp4(file_path, file_path)
+ else:
+ # Try to find an alternative file with the same base name but different extension
+ alternative_file = find_alternative(file_path)
+ if alternative_file:
+ print(f"\nFound alternative: {alternative_file}")
+ # Convert the alternative file to mp4 and save with the desired file_path
+ convert_to_mp4(alternative_file, file_path)
+ else:
+ print(f"File does not exist and no alternative found: {file_path}")
+
+if __name__ == "__main__":
+ # Change this to the path of your JSON file
+ process_videos_from_json("eval_mvbench.json")
diff --git a/src/r1-v/configs/ddp.yaml b/src/r1-v/configs/ddp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f0557131aa2c1bded4cb4cfdc1cc58a3b25765b
--- /dev/null
+++ b/src/r1-v/configs/ddp.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/r1-v/configs/qwen2vl_sft_config.yaml b/src/r1-v/configs/qwen2vl_sft_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd4351b11f63344c3ba5090ec0682c38ca802943
--- /dev/null
+++ b/src/r1-v/configs/qwen2vl_sft_config.yaml
@@ -0,0 +1,37 @@
+# Model arguments
+model_name_or_path: Qwen/Qwen2-VL-2B-Instruct
+model_revision: main
+torch_dtype: bfloat16
+
+# Data training arguments
+dataset_name: /home/test/test08/fkt/R1-V-main/GEOQA_R1V_Train_8K
+dataset_configs:
+- all
+preprocessing_num_workers: 4
+
+# SFT trainer config
+bf16: true
+do_eval: true
+eval_strategy: "no"
+gradient_accumulation_steps: 4
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+learning_rate: 2.0e-05
+log_level: info
+logging_steps: 5
+logging_strategy: steps
+lr_scheduler_type: cosine
+packing: true
+max_seq_length: 4096
+max_steps: -1
+num_train_epochs: 1
+output_dir: ./log/Qwen2-VL-2B-Instruct-SFT
+overwrite_output_dir: true
+per_device_eval_batch_size: 1
+per_device_train_batch_size: 1
+report_to:
+- wandb
+save_strategy: "no"
+seed: 42
+warmup_ratio: 0.1
\ No newline at end of file
diff --git a/src/r1-v/configs/zero2.yaml b/src/r1-v/configs/zero2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..92f25e6a85a8de167f023357fade50b978b81acc
--- /dev/null
+++ b/src/r1-v/configs/zero2.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/src/r1-v/configs/zero3.yaml b/src/r1-v/configs/zero3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b5a1201f8a2ee8706b63f0f80c664a1fc61a7d9d
--- /dev/null
+++ b/src/r1-v/configs/zero3.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/r1-v/eval_results/empty.json b/src/r1-v/eval_results/empty.json
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/src/r1-v/eval_results/empty.json
@@ -0,0 +1 @@
+
diff --git a/src/r1-v/local_scripts/create_vision_cot_data.py b/src/r1-v/local_scripts/create_vision_cot_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec2d7c245b1ddedc615d97a88cf67d6711d3333
--- /dev/null
+++ b/src/r1-v/local_scripts/create_vision_cot_data.py
@@ -0,0 +1,153 @@
+import argparse
+import base64
+import concurrent.futures
+import io
+import json
+import os
+import random
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from io import BytesIO
+from typing import Dict, List
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
+from tqdm import tqdm
+
+import bytedtos
+import seaborn as sns
+import yaml
+from openai import AzureOpenAI
+from PIL import Image
+from pillow_avif import AvifImagePlugin
+
+
+PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
+
+Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
+
+Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
+
+Input Format:
+Original Question: {original_question}
+Original Answer: {original_answer}
+
+Output Format:
+Question: [rewrite the question if necessary]
+Answer: [answer with reasoning steps, including calculations where applicable]
+step-by-step reasoning process
+easy to verify answer
+"""
+
+
+def get_image_data_url(image_input):
+ if isinstance(image_input, str) and image_input.startswith("data:"):
+ return image_input
+
+ if isinstance(image_input, str) and image_input.startswith("http"):
+ image_input = load_image(image_input)
+
+ if isinstance(image_input, str):
+ image_input = Image.open(image_input)
+
+ if not isinstance(image_input, Image.Image):
+ raise ValueError("Unsupported image input type")
+
+ if image_input.mode != "RGB":
+ image_input = image_input.convert("RGB")
+
+ buffer = BytesIO()
+ image_input.save(buffer, format="JPEG")
+ img_bytes = buffer.getvalue()
+ base64_data = base64.b64encode(img_bytes).decode("utf-8")
+ return f"data:image/jpeg;base64,{base64_data}"
+
+
+def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
+ if image is None:
+ return None
+
+ data_url_list = [get_image_data_url(image)]
+ client = AzureOpenAI(
+ azure_endpoint="YOUR_AZURE_ENDPOINT",
+ api_version="2023-07-01-preview",
+ api_key="YOUR_API_KEY",
+ )
+
+ for attempt in range(max_retries):
+ try:
+ messages = [
+ {
+ "role": "system",
+ "content": "You are an expert to analyze the image and provide useful information for users.",
+ },
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ ],
+ },
+ ]
+
+ for data_url in data_url_list:
+ messages[1]["content"].insert(
+ 0, {"type": "image_url", "image_url": {"url": data_url}}
+ )
+
+ response = client.chat.completions.create(
+ model="gpt-4o-2024-08-06",
+ messages=messages,
+ temperature=0.2,
+ max_tokens=8192,
+ )
+ return response.choices[0].message.content
+
+ except Exception as e:
+ if attempt == max_retries - 1:
+ raise Exception(
+ f"Failed after {max_retries} attempts. Last error: {str(e)}"
+ )
+ delay = initial_delay * (2**attempt) + random.uniform(
+ 0, 0.1 * initial_delay * (2**attempt)
+ )
+ time.sleep(delay)
+
+
+def process_single_item(example):
+ try:
+ image_path = example["image_path"]
+ formatted_prompt = PROMPT_FORMAT.format(
+ original_question=example["question"], original_answer=example["answer"]
+ )
+
+ response = gpt4o_query(image_path, formatted_prompt)
+ example["gpt4o_response"] = response
+ return example
+ except Exception as e:
+ print(f"Error processing item: {str(e)}")
+ example["gpt4o_response"] = None
+ return example
+
+
+def main():
+ dataset_path = "path/to/your/dataset"
+ full_dataset = load_from_disk(dataset_path)
+
+ processed_dataset = full_dataset.map(
+ function=partial(process_single_item),
+ num_proc=256,
+ desc="Processing dataset with GPT-4o",
+ keep_in_memory=True,
+ )
+
+ output_path = f"{dataset_path}_processed"
+ processed_dataset.save_to_disk(output_path)
+ print(f"Processed dataset saved to: {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/r1-v/local_scripts/lmms_eval_qwen2vl.sh b/src/r1-v/local_scripts/lmms_eval_qwen2vl.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d38769aa91029d63880a5dfc6f9cf64bb36c31a
--- /dev/null
+++ b/src/r1-v/local_scripts/lmms_eval_qwen2vl.sh
@@ -0,0 +1,61 @@
+export HF_HOME=""
+export HF_TOKEN=""
+export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+export API_TYPE=""
+export AZURE_ENDPOINT=""
+export AZURE_API_KEY=""
+export API_VERSION=""
+export MODEL_VERSION=""
+export NAVIT_ATTENTION_IMPLEMENTATION="eager"
+
+# Prompt for installation with 3-second timeout
+read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
+if [ "$install_deps" = "YES" ]; then
+ # Prepare the environment
+ pip3 install --upgrade pip
+ pip3 install -U setuptools
+
+ cd
+ if [ ! -d "maas_engine" ]; then
+ git clone
+ else
+ echo "maas_engine directory already exists, skipping clone"
+ fi
+ cd maas_engine
+ git pull
+ git checkout
+ pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
+
+ current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
+ if [ "$current_version" != "4.46.2" ]; then
+ echo "Installing transformers 4.46.2 (current version: $current_version)"
+ pip3 install transformers==4.46.2
+ else
+ echo "transformers 4.46.2 is already installed"
+ fi
+
+ cd
+ rm -rf
+ pip3 install -e .
+ pip3 install -U pydantic
+ pip3 install Levenshtein
+ pip3 install nltk
+ python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
+fi
+
+TASKS=mmmu_val,mathvista_testmini,mmmu_pro
+MODEL_BASENAME=qwen2_vl
+
+model_checkpoint=""
+echo "MODEL_BASENAME: ${MODEL_BASENAME}"
+cd
+
+python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
+ --model qwen2_vl \
+ --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
+ --tasks ${TASKS} \
+ --batch_size 1 \
+ --log_samples \
+ --log_samples_suffix ${MODEL_BASENAME} \
+ --output_path ./logs
\ No newline at end of file
diff --git a/src/r1-v/local_scripts/prepare_hf_data.py b/src/r1-v/local_scripts/prepare_hf_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..62eab9e0fbba24ce354a10846fb8404abde9feaa
--- /dev/null
+++ b/src/r1-v/local_scripts/prepare_hf_data.py
@@ -0,0 +1,166 @@
+import matplotlib.pyplot as plt
+import seaborn as sns
+import pandas as pd
+import random
+from typing import List, Dict
+import numpy as np
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+import datasets
+
+import io
+from datasets import load_dataset, load_from_disk, concatenate_datasets
+from PIL import Image
+from tqdm import tqdm
+from functools import partial
+from pillow_avif import AvifImagePlugin
+from datasets import Dataset
+import json
+import yaml
+import os
+import re
+import time
+import random
+import base64
+from openai import AzureOpenAI
+import concurrent.futures
+from typing import List, Dict
+import argparse
+import time
+
+
+def extract_problem_solution(gpt4o_response):
+ # Split the response into parts
+ parts = gpt4o_response.split("")
+
+ # Extract the problem (first part before any tags)
+ problem = parts[0].strip()
+ # Remove "Question:" prefix if it exists
+ problem = re.sub(r"^Question:\s*", "", problem)
+ # Remove "Answer:" at the end of the problem
+ problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
+
+ # Combine all the reasoning steps into a single block
+ think_parts = [p.split("")[0].strip() for p in parts[1:] if "" in p]
+ solution = f"{' '.join(think_parts)}"
+
+ # Add the final answer if it exists, removing "Answer:" prefix
+ if "" in gpt4o_response:
+ final_answer = (
+ gpt4o_response.split("")[-1].split("")[0].strip()
+ )
+ final_answer = re.sub(r"^Answer:\s*", "", final_answer)
+ solution += f"\n\n{final_answer}"
+
+ return problem, solution
+
+
+def load_image_from_path(image_path):
+ try:
+ img = Image.open(image_path)
+ return img
+ except Exception as e:
+ print(f"Error loading image {image_path}: {str(e)}")
+ return None
+
+
+def process_raw_data(raw_data):
+ # Parse the raw data if it's a string
+ if isinstance(raw_data, str):
+ data = json.loads(raw_data)
+ else:
+ data = raw_data
+
+ # Extract problem and solution
+ try:
+ problem, solution = extract_problem_solution(data["gpt4o_response"])
+ image = load_image_from_path(data["image_path"])
+
+ return {
+ "image": image,
+ "problem": problem,
+ "solution": solution,
+ "original_question": data["question"],
+ "original_answer": data["answer"],
+ }
+ except Exception as e:
+ print(f"Error processing data {data}: {str(e)}")
+ return {
+ "image": None,
+ "problem": None,
+ "solution": None,
+ "original_question": None,
+ "original_answer": None,
+ }
+
+
+raw_data_list = [
+ "/path/to/reasoning_data_with_response_90k_verified",
+]
+
+raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
+
+processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
+
+hf_dict = {
+ "image": [],
+ "problem": [],
+ "solution": [],
+ "original_question": [],
+ "original_answer": [],
+}
+
+for item in tqdm(processed_data):
+ hf_dict["image"].append(item["image"])
+ hf_dict["problem"].append(item["problem"])
+ hf_dict["solution"].append(item["solution"])
+ hf_dict["original_question"].append(item["original_question"])
+ hf_dict["original_answer"].append(item["original_answer"])
+
+
+features = datasets.Features(
+ {
+ "image": datasets.Image(),
+ "problem": datasets.Value("string"),
+ "solution": datasets.Value("string"),
+ "original_question": datasets.Value("string"),
+ "original_answer": datasets.Value("string"),
+ }
+)
+
+
+def has_empty_tags(text):
+ # Pattern to match empty tags like
+ pattern = r"<[^>]+>[^>]+>"
+ return bool(re.search(pattern, text))
+
+
+def has_answer_pattern(text):
+ if "Answer:" in text:
+ return True
+ return False
+
+
+def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
+ # Assuming the image is in a format that can be checked for dimensions
+ # You might need to adjust this depending on how the image is stored in your dataset
+ try:
+ image = example["image"] # or however your image is accessed
+ if isinstance(image, dict) and "height" in image and "width" in image:
+ return image["height"] >= 28 and image["width"] >= 28
+ # If image is a PIL Image or similar
+ return image.height >= 28 and image.width >= 28
+ except:
+ return False
+
+
+ds = datasets.Dataset.from_dict(hf_dict, features=features)
+ds = ds.filter(
+ lambda x: not has_empty_tags(x["solution"])
+ and not has_answer_pattern(x["problem"])
+ and has_valid_image_size(x)
+ and x["image"] is not None,
+ num_proc=128,
+)
+# Push to Hugging Face Hub
+ds.push_to_hub("path/to/your/dataset")
diff --git a/src/r1-v/local_scripts/train_aria_moe.sh b/src/r1-v/local_scripts/train_aria_moe.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5a3b6966c4a40ff4760e4d1cb0d7518448c30fae
--- /dev/null
+++ b/src/r1-v/local_scripts/train_aria_moe.sh
@@ -0,0 +1,68 @@
+#!/bin/bash
+
+export NCCL_BLOCKING_WAIT=0
+export TOKENIZERS_PARALLELISM=false
+export OMP_NUM_THREADS=8
+export NCCL_IB_DISABLE=0
+export NCCL_IB_GID_INDEX=3
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_DEBUG=INFO
+
+# CONFIG Huggingface
+# export HF_TOKEN=""
+export HF_TOKEN=""
+export HF_HOME="$HOME/.cache/huggingface"
+export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+export NCCL_DEBUG=INFO
+
+GPUS="0,1,2,3,4,5,6,7"
+
+# 取 worker0 第一个 port
+ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+port=${ports[0]}
+port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+echo "total workers: ${ARNOLD_WORKER_NUM}"
+echo "cur worker id: ${ARNOLD_ID}"
+echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+echo "master ip: ${METIS_WORKER_0_HOST}"
+echo "master port: ${port}"
+echo "master port in cmd: ${port_in_cmd}"
+
+# export WANDB_BASE_URL=https://api.wandb.ai
+# export WANDB_API_KEY=""
+# wandb login $WANDB_API_KEY
+
+export WANDB_BASE_URL=https://api.wandb.ai
+export WANDB_PROJECT=vision-reasoning
+export WANDB_API_KEY=""
+export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+wandb login $WANDB_API_KEY
+
+cd /home/tiger/multimodal-open-r1
+# pip3 install vllm==0.6.6.post1
+pip3 install -e ".[dev]"
+pip3 install wandb==0.18.3
+
+torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+ --nnodes="${ARNOLD_WORKER_NUM}" \
+ --node_rank="${ARNOLD_ID}" \
+ --master_addr="${METIS_WORKER_0_HOST}" \
+ --master_port="${port_in_cmd}" \
+ src/open_r1/grpo.py \
+ --deepspeed scripts/zero3.json \
+ --output_dir Aria-GRPO-mini_cot_80k \
+ --model_name_or_path rhymes-ai/Aria \
+ --dataset_name luodian/mini_cot_80k \
+ --max_prompt_length 8192 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation eager \
+ --save_total_limit 8 \
+ --num_train_epochs 1 \
+ --run_name $WANDB_RUN_NAME
diff --git a/src/r1-v/local_scripts/train_qwen2_vl.sh b/src/r1-v/local_scripts/train_qwen2_vl.sh
new file mode 100644
index 0000000000000000000000000000000000000000..137310e4438c645bfb6f89f254c50164f23f5a9d
--- /dev/null
+++ b/src/r1-v/local_scripts/train_qwen2_vl.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+
+export NCCL_BLOCKING_WAIT=0
+export TOKENIZERS_PARALLELISM=false
+export OMP_NUM_THREADS=8
+export NCCL_IB_DISABLE=0
+export NCCL_IB_GID_INDEX=3
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_DEBUG=INFO
+
+GPUS="0,1,2,3,4,5,6,7"
+
+# 取 worker0 第一个 port
+ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+port=${ports[0]}
+port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+echo "total workers: ${ARNOLD_WORKER_NUM}"
+echo "cur worker id: ${ARNOLD_ID}"
+echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+echo "master ip: ${METIS_WORKER_0_HOST}"
+echo "master port: ${port}"
+echo "master port in cmd: ${port_in_cmd}"
+
+# export WANDB_BASE_URL=https://api.wandb.ai
+# export WANDB_API_KEY=""
+# wandb login $WANDB_API_KEY
+
+export WANDB_BASE_URL=https://api.wandb.ai
+export WANDB_PROJECT=vision-reasoning
+export WANDB_API_KEY=""
+export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+wandb login $WANDB_API_KEY
+
+cd /home/tiger/multimodal-open-r1
+# pip3 install vllm==0.6.6.post1
+pip3 install -e ".[dev]"
+pip3 install wandb==0.18.3
+
+torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+ --nnodes="${ARNOLD_WORKER_NUM}" \
+ --node_rank="${ARNOLD_ID}" \
+ --master_addr="${METIS_WORKER_0_HOST}" \
+ --master_port="${port_in_cmd}" \
+ src/open_r1/grpo.py \
+ --deepspeed scripts/zero3.json \
+ --output_dir checkpoints/${WANDB_RUN_NAME} \
+ --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
+ --dataset_name luodian/${DATASET_NAME} \
+ --max_prompt_length 8192 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation flash_attention_2 \
+ --max_pixels 2359296 \
+ --save_total_limit 8 \
+ --num_train_epochs 1 \
+ --run_name $WANDB_RUN_NAME
diff --git a/src/r1-v/local_scripts/zero1_no_optimizer.json b/src/r1-v/local_scripts/zero1_no_optimizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..5f36063c686c8436cdec1569052efb5fdf33f8d4
--- /dev/null
+++ b/src/r1-v/local_scripts/zero1_no_optimizer.json
@@ -0,0 +1,29 @@
+{
+ "zero_optimization": {
+ "stage": 1,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 1e9,
+ "overlap_comm": false,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 1e9,
+ "contiguous_gradients": true
+ },
+ "fp16": {
+ "enabled": "auto",
+ "auto_cast": true,
+ "loss_scale": 0,
+ "initial_scale_power": 32,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 1,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": true
+}
\ No newline at end of file
diff --git a/src/r1-v/local_scripts/zero2.json b/src/r1-v/local_scripts/zero2.json
new file mode 100644
index 0000000000000000000000000000000000000000..b5ba7ebea0f236230a5a41d72ec23ae1f64130d6
--- /dev/null
+++ b/src/r1-v/local_scripts/zero2.json
@@ -0,0 +1,41 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "allgather_partitions": true,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": false,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/src/r1-v/local_scripts/zero2_1.json b/src/r1-v/local_scripts/zero2_1.json
new file mode 100644
index 0000000000000000000000000000000000000000..80906a4e2ab253a4fd089f50c43f3104c7804b60
--- /dev/null
+++ b/src/r1-v/local_scripts/zero2_1.json
@@ -0,0 +1,41 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "allgather_partitions": true,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": false,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/src/r1-v/local_scripts/zero3.json b/src/r1-v/local_scripts/zero3.json
new file mode 100644
index 0000000000000000000000000000000000000000..02d343165ec0eec3af55d3285f45911769af6109
--- /dev/null
+++ b/src/r1-v/local_scripts/zero3.json
@@ -0,0 +1,41 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/src/r1-v/local_scripts/zero3.yaml b/src/r1-v/local_scripts/zero3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b5a1201f8a2ee8706b63f0f80c664a1fc61a7d9d
--- /dev/null
+++ b/src/r1-v/local_scripts/zero3.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/r1-v/local_scripts/zero3_offload.json b/src/r1-v/local_scripts/zero3_offload.json
new file mode 100644
index 0000000000000000000000000000000000000000..9da12de56b44374047644fe77607a85ced885e7c
--- /dev/null
+++ b/src/r1-v/local_scripts/zero3_offload.json
@@ -0,0 +1,48 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "steps_per_print": 1e5,
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/src/r1-v/log/Qwen2.5-VL-3B-Video-GRPO-LLMEval-Train-QA10K/training_log.txt b/src/r1-v/log/Qwen2.5-VL-3B-Video-GRPO-LLMEval-Train-QA10K/training_log.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ff61564c2532687552387e90fe783e4113f2156b
--- /dev/null
+++ b/src/r1-v/log/Qwen2.5-VL-3B-Video-GRPO-LLMEval-Train-QA10K/training_log.txt
@@ -0,0 +1,306 @@
+W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793]
+W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793] *****************************************
+W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793] *****************************************
+Traceback (most recent call last):
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1858, in _get_module
+ return importlib.import_module("." + module_name, self.__name__)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/importlib/__init__.py", line 126, in import_module
+ return _bootstrap._gcd_import(name[level:], package, level)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "", line 1204, in _gcd_import
+ File "", line 1176, in _find_and_load
+ File "", line 1147, in _find_and_load_unlocked
+ File "", line 690, in _load_unlocked
+ File "", line 940, in exec_module
+ File "", line 241, in _call_with_frames_removed
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in
+ from ...modeling_utils import PreTrainedModel
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in
+ from .integrations.flash_attention import flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in
+ from ..modeling_flash_attention_utils import _flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in
+ from flash_attn.flash_attn_interface import (
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in
+ import flash_attn_2_cuda as flash_attn_gpu
+ImportError: /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+
+The above exception was the direct cause of the following exception:
+
+Traceback (most recent call last):
+ File "/cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/grpo-cot-LLMEval.py", line 21, in
+ from transformers import Qwen2VLForConditionalGeneration
+ File "", line 1229, in _handle_fromlist
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1847, in __getattr__
+ value = getattr(module, name)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1846, in __getattr__
+ module = self._get_module(self._class_to_module[name])
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1860, in _get_module
+ raise RuntimeError(
+RuntimeError: Failed to import transformers.models.qwen2_vl.modeling_qwen2_vl because of the following error (look up to see its traceback):
+/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+Traceback (most recent call last):
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1858, in _get_module
+ return importlib.import_module("." + module_name, self.__name__)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/importlib/__init__.py", line 126, in import_module
+ return _bootstrap._gcd_import(name[level:], package, level)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "", line 1204, in _gcd_import
+ File "", line 1176, in _find_and_load
+ File "", line 1147, in _find_and_load_unlocked
+ File "", line 690, in _load_unlocked
+ File "", line 940, in exec_module
+ File "", line 241, in _call_with_frames_removed
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in
+ from ...modeling_utils import PreTrainedModel
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in
+ from .integrations.flash_attention import flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in
+ from ..modeling_flash_attention_utils import _flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in
+ from flash_attn.flash_attn_interface import (
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in
+ import flash_attn_2_cuda as flash_attn_gpu
+ImportError: /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+
+The above exception was the direct cause of the following exception:
+
+Traceback (most recent call last):
+ File "/cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/grpo-cot-LLMEval.py", line 21, in
+ from transformers import Qwen2VLForConditionalGeneration
+ File "", line 1229, in _handle_fromlist
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1847, in __getattr__
+ value = getattr(module, name)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1846, in __getattr__
+ module = self._get_module(self._class_to_module[name])
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1860, in _get_module
+ raise RuntimeError(
+RuntimeError: Failed to import transformers.models.qwen2_vl.modeling_qwen2_vl because of the following error (look up to see its traceback):
+/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+Traceback (most recent call last):
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1858, in _get_module
+ return importlib.import_module("." + module_name, self.__name__)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/importlib/__init__.py", line 126, in import_module
+ return _bootstrap._gcd_import(name[level:], package, level)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "", line 1204, in _gcd_import
+ File "", line 1176, in _find_and_load
+ File "", line 1147, in _find_and_load_unlocked
+ File "", line 690, in _load_unlocked
+ File "", line 940, in exec_module
+ File "", line 241, in _call_with_frames_removed
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in
+ from ...modeling_utils import PreTrainedModel
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in
+ from .integrations.flash_attention import flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in
+ from ..modeling_flash_attention_utils import _flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in
+ from flash_attn.flash_attn_interface import (
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in
+ import flash_attn_2_cuda as flash_attn_gpu
+ImportError: /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+
+The above exception was the direct cause of the following exception:
+
+Traceback (most recent call last):
+ File "/cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/grpo-cot-LLMEval.py", line 21, in
+ from transformers import Qwen2VLForConditionalGeneration
+ File "", line 1229, in _handle_fromlist
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1847, in __getattr__
+ value = getattr(module, name)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1846, in __getattr__
+ module = self._get_module(self._class_to_module[name])
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1860, in _get_module
+ raise RuntimeError(
+RuntimeError: Failed to import transformers.models.qwen2_vl.modeling_qwen2_vl because of the following error (look up to see its traceback):
+/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+Traceback (most recent call last):
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1858, in _get_module
+ return importlib.import_module("." + module_name, self.__name__)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/importlib/__init__.py", line 126, in import_module
+ return _bootstrap._gcd_import(name[level:], package, level)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "", line 1204, in _gcd_import
+ File "", line 1176, in _find_and_load
+ File "", line 1147, in _find_and_load_unlocked
+ File "", line 690, in _load_unlocked
+ File "", line 940, in exec_module
+ File "", line 241, in _call_with_frames_removed
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in
+ from ...modeling_utils import PreTrainedModel
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in
+ from .integrations.flash_attention import flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in
+ from ..modeling_flash_attention_utils import _flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in
+ from flash_attn.flash_attn_interface import (
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in
+ import flash_attn_2_cuda as flash_attn_gpu
+ImportError: /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+
+The above exception was the direct cause of the following exception:
+
+Traceback (most recent call last):
+ File "/cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/grpo-cot-LLMEval.py", line 21, in
+ from transformers import Qwen2VLForConditionalGeneration
+ File "", line 1229, in _handle_fromlist
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1847, in __getattr__
+ value = getattr(module, name)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1846, in __getattr__
+ module = self._get_module(self._class_to_module[name])
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1860, in _get_module
+ raise RuntimeError(
+RuntimeError: Failed to import transformers.models.qwen2_vl.modeling_qwen2_vl because of the following error (look up to see its traceback):
+/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+Traceback (most recent call last):
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1858, in _get_module
+ return importlib.import_module("." + module_name, self.__name__)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/importlib/__init__.py", line 126, in import_module
+ return _bootstrap._gcd_import(name[level:], package, level)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "", line 1204, in _gcd_import
+ File "", line 1176, in _find_and_load
+ File "", line 1147, in _find_and_load_unlocked
+ File "", line 690, in _load_unlocked
+ File "", line 940, in exec_module
+ File "", line 241, in _call_with_frames_removed
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in
+ from ...modeling_utils import PreTrainedModel
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in
+ from .integrations.flash_attention import flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in
+ from ..modeling_flash_attention_utils import _flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in
+ from flash_attn.flash_attn_interface import (
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in
+ import flash_attn_2_cuda as flash_attn_gpu
+ImportError: /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+
+The above exception was the direct cause of the following exception:
+
+Traceback (most recent call last):
+ File "/cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/grpo-cot-LLMEval.py", line 21, in
+ from transformers import Qwen2VLForConditionalGeneration
+ File "", line 1229, in _handle_fromlist
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1847, in __getattr__
+ value = getattr(module, name)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1846, in __getattr__
+ module = self._get_module(self._class_to_module[name])
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1860, in _get_module
+ raise RuntimeError(
+RuntimeError: Failed to import transformers.models.qwen2_vl.modeling_qwen2_vl because of the following error (look up to see its traceback):
+/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+Traceback (most recent call last):
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1858, in _get_module
+ return importlib.import_module("." + module_name, self.__name__)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/importlib/__init__.py", line 126, in import_module
+ return _bootstrap._gcd_import(name[level:], package, level)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "", line 1204, in _gcd_import
+ File "", line 1176, in _find_and_load
+ File "", line 1147, in _find_and_load_unlocked
+ File "", line 690, in _load_unlocked
+ File "", line 940, in exec_module
+ File "", line 241, in _call_with_frames_removed
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in
+ from ...modeling_utils import PreTrainedModel
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in
+ from .integrations.flash_attention import flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in
+ from ..modeling_flash_attention_utils import _flash_attention_forward
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in
+ from flash_attn.flash_attn_interface import (
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in
+ import flash_attn_2_cuda as flash_attn_gpu
+ImportError: /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+
+The above exception was the direct cause of the following exception:
+
+Traceback (most recent call last):
+ File "/cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/grpo-cot-LLMEval.py", line 21, in
+ from transformers import Qwen2VLForConditionalGeneration
+ File "", line 1229, in _handle_fromlist
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1847, in __getattr__
+ value = getattr(module, name)
+ ^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1846, in __getattr__
+ module = self._get_module(self._class_to_module[name])
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1860, in _get_module
+ raise RuntimeError(
+RuntimeError: Failed to import transformers.models.qwen2_vl.modeling_qwen2_vl because of the following error (look up to see its traceback):
+/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+W0624 15:42:40.776000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019035 closing signal SIGTERM
+W0624 15:42:40.776000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019036 closing signal SIGTERM
+W0624 15:42:40.777000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019037 closing signal SIGTERM
+W0624 15:42:40.778000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019038 closing signal SIGTERM
+W0624 15:42:40.779000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019040 closing signal SIGTERM
+E0624 15:42:41.558000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 4 (pid: 1019039) of binary: /root/miniconda3/envs/video-r1-35/bin/python3.11
+Traceback (most recent call last):
+ File "/root/miniconda3/envs/video-r1-35/bin/torchrun", line 8, in
+ sys.exit(main())
+ ^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
+ return f(*args, **kwargs)
+ ^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/run.py", line 919, in main
+ run(args)
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/run.py", line 910, in run
+ elastic_launch(
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
+ return launch_agent(self._config, self._entrypoint, list(args))
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
+ raise ChildFailedError(
+torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
+============================================================
+src/open_r1/grpo-cot-LLMEval.py FAILED
+------------------------------------------------------------
+Failures:
+
+------------------------------------------------------------
+Root Cause (first observed failure):
+[0]:
+ time : 2025-06-24_15:42:40
+ host : TENCENT64.site
+ rank : 4 (local_rank: 4)
+ exitcode : 1 (pid: 1019039)
+ error_file:
+ traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
+============================================================
diff --git a/src/r1-v/src/open_r1/trainer/grpo_trainer.py b/src/r1-v/src/open_r1/trainer/grpo_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..78dadfb24fb4f6834034fc135bb23eaf78c49dbe
--- /dev/null
+++ b/src/r1-v/src/open_r1/trainer/grpo_trainer.py
@@ -0,0 +1,786 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import textwrap
+from collections import defaultdict
+from typing import Any, Callable, Optional, Union
+import random
+
+import torch
+import torch.utils.data
+import transformers
+from datasets import Dataset, IterableDataset
+from packaging import version
+from transformers import (
+ AriaForConditionalGeneration,
+ AriaProcessor,
+ AutoModelForCausalLM,
+ AutoModelForSequenceClassification,
+ AutoProcessor,
+ AutoTokenizer,
+ GenerationConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Qwen2VLForConditionalGeneration,
+ Qwen2_5_VLForConditionalGeneration,
+ Trainer,
+ TrainerCallback,
+ is_wandb_available,
+)
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import is_peft_available
+
+from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
+from trl.trainer.grpo_config import GRPOConfig
+from trl.trainer.utils import generate_model_card, get_comet_experiment_url
+
+from qwen_vl_utils import process_vision_info
+
+import copy
+
+
+if is_peft_available():
+ from peft import PeftConfig, get_peft_model
+
+if is_wandb_available():
+ import wandb
+
+
+# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+
+class Qwen2VLGRPOTrainer(Trainer):
+ """
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
+
+ Example:
+
+ ```python
+ from datasets import load_dataset
+ from trl import GRPOTrainer
+
+ dataset = load_dataset("trl-lib/tldr", split="train")
+
+ trainer = GRPOTrainer(
+ model="Qwen/Qwen2-0.5B-Instruct",
+ reward_funcs="weqweasdas/RM-Gemma-2B",
+ train_dataset=dataset,
+ )
+
+ trainer.train()
+ ```
+
+ Args:
+ model (`Union[str, PreTrainedModel]`):
+ Model to be trained. Can be either:
+
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
+ a path to a *directory* containing model weights saved using
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
+ in `args.model_init_kwargs`.
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
+ functions with the prompts and completions and sum the rewards. Can be either:
+
+ - A single reward function, such as:
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
+ path to a *directory* containing model weights saved using
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
+ keyword arguments in `args.model_init_kwargs`.
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
+ - A custom reward function: The function is provided with the prompts and the generated completions,
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
+ [Using a custom reward function](#using-a-custom-reward-function).
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
+ Configuration for this trainer. If `None`, a default configuration is used.
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
+ ignored. The format of the samples can be either:
+
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+ and content).
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
+
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
+ the corresponding entries in `reward_processing_classes` are ignored.
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+ method.
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
+ """
+
+ def __init__(
+ self,
+ model: Union[str, PreTrainedModel],
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
+ args: GRPOConfig = None,
+ script_args = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+ peft_config: Optional["PeftConfig"] = None,
+ max_pixels: Optional[int] = 12845056,
+ min_pixels: Optional[int] = 3136,
+ attn_implementation: str = "flash_attention_2",
+ ):
+ # Args
+ if args is None:
+ model_name = model if isinstance(model, str) else model.config._name_or_path
+ model_name = model_name.split("/")[-1]
+ args = GRPOConfig(f"{model_name}-GRPO")
+
+
+ # Models
+ # Trained model
+ model_init_kwargs = args.model_init_kwargs or {}
+ model_init_kwargs["attn_implementation"] = attn_implementation
+ if isinstance(model, str):
+ model_id = model
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+ pass # torch_dtype is already a torch.dtype or "auto" or None
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
+ torch_dtype = getattr(torch, torch_dtype)
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ else:
+ raise ValueError(
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+ )
+ # Disable caching if gradient checkpointing is enabled (not supported)
+ model_init_kwargs["use_cache"] = (
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+ )
+ if "Qwen2-VL" in model_id:
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ elif "Qwen2.5-VL" in model_id:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ elif "Aria" in model_id:
+ model_init_kwargs.pop("use_cache")
+ model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ else:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ # model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ else:
+ model_id = model.config._name_or_path
+ if args.model_init_kwargs is not None:
+ raise ValueError(
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+ "This argument can only be used when the `model` argument is a string."
+ )
+
+ if peft_config is not None:
+ model = get_peft_model(model, peft_config)
+
+ #self.ref_model = None
+ # Reference model
+ if is_deepspeed_zero3_enabled():
+ if "Qwen2-VL" in model_id:
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+ elif "Qwen2.5-VL" in model_id:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+ elif "Aria" in model_id:
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+ else:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+ # self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+ elif peft_config is None:
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
+ self.ref_model = create_reference_model(model)
+ else:
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
+ # to revert to the initial model.
+ self.ref_model = None
+
+ # Processing class
+ if processing_class is None:
+ if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id or True:
+ processing_class = AutoProcessor.from_pretrained(model_id)
+ pad_token_id = processing_class.tokenizer.pad_token_id
+ processing_class.pad_token_id = pad_token_id
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+ if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
+ processing_class.image_processor.max_pixels = max_pixels
+ processing_class.image_processor.min_pixels = min_pixels
+ else:
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
+ pad_token_id = processing_class.pad_token_id
+
+ # Reward functions
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+ for i, reward_func in enumerate(reward_funcs):
+ if isinstance(reward_func, str):
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+ reward_func, num_labels=1, **model_init_kwargs
+ )
+ self.reward_funcs = reward_funcs
+
+ # Reward processing class
+ if reward_processing_classes is None:
+ reward_processing_classes = [None] * len(reward_funcs)
+ elif not isinstance(reward_processing_classes, list):
+ reward_processing_classes = [reward_processing_classes]
+ else:
+ if len(reward_processing_classes) != len(reward_funcs):
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
+
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
+ if isinstance(reward_func, PreTrainedModel):
+ if reward_processing_class is None:
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+ if reward_processing_class.pad_token_id is None:
+ reward_processing_class.pad_token = reward_processing_class.eos_token
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+ reward_processing_classes[i] = reward_processing_class
+ self.reward_processing_classes = reward_processing_classes
+
+ # Data collator
+ def data_collator(features): # No data collation is needed in GRPO
+ return features
+
+ # Training arguments
+ self.max_prompt_length = args.max_prompt_length
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
+ self.num_generations = args.num_generations # = G in the GRPO paper
+ self.temporal = script_args.temporal
+ self.generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1, # HACK
+ num_return_sequences=self.num_generations,
+ pad_token_id=pad_token_id,
+ )
+ self.shuffled_num_generations = self.num_generations // 2
+ self.shuffled_generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1, # HACK
+ num_return_sequences=self.shuffled_num_generations,
+ pad_token_id=pad_token_id,
+ )
+
+ self.dummy_generation_config = GenerationConfig(
+ max_new_tokens=1,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1, # HACK
+ num_return_sequences=1,
+ pad_token_id=pad_token_id,
+ )
+ self.len_control = script_args.len_control
+ self.beta = args.beta
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+ # This acts as a flag to indicate that the warning has already been issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ # Initialize the metrics
+ self._metrics = defaultdict(list)
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ )
+
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ if self.ref_model is not None:
+ if self.is_deepspeed_enabled:
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+ def _set_signature_columns_if_needed(self):
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+ if self._signature_columns is None:
+ self._signature_columns = ["prompt"]
+
+
+ # Get the per-token log probabilities for the completions for the model and the reference model
+ def _get_per_token_logps(self, model, input_ids, **kwargs):
+ # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
+ # import pdb
+ # pdb.set_trace()
+ logits = model(input_ids, **kwargs).logits
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+ per_token_logps = []
+ for logits_row, input_ids_row in zip(logits, input_ids):
+ log_probs = logits_row.log_softmax(dim=-1)
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+ per_token_logps.append(token_log_prob)
+ return torch.stack(per_token_logps)
+
+ def remove_none_from_data(self, data):
+ for entry in data:
+ if "content" in entry and isinstance(entry["content"], list):
+ for sub_entry in entry["content"]:
+ if isinstance(sub_entry, dict):
+ keys_to_remove = [k for k, v in sub_entry.items() if v is None]
+ for k in keys_to_remove:
+ del sub_entry[k]
+ return data
+
+
+ # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
+ # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
+ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+ return inputs
+
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ if return_outputs:
+ raise ValueError("The GRPOTrainer does not support returning outputs")
+
+
+
+ prompts = [x["prompt"] for x in inputs]
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+
+
+
+ input_copy = copy.deepcopy(inputs[0]['prompt'])
+
+ input_copy = self.remove_none_from_data(input_copy)
+
+ if inputs[0]['data_type'] == 'image':
+ input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+ elif inputs[0]['data_type'] == 'video':
+ input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+
+ try:
+ image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+ except Exception as e:
+ print(f"process_vision_info error, using fixed data, {e}")
+ if inputs[0]['data_type'] == 'image':
+ input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + '/Math/Multimath-300k/17ff4c7d14c388134de02381b1fc2824.png'
+ elif inputs[0]['data_type'] == 'video':
+ input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + '/LLaVA-Video-178K/liwei_youtube_videos/videos/youtube_video_2024/ytb_7nRmsEw7nsE.mp4'
+
+ image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+
+
+ prompt_inputs = self.processing_class(
+ text=copy.deepcopy(prompts_text),
+ images=image_inputs,
+ videos=video_inputs,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ add_special_tokens=False,
+ )
+
+
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
+
+
+ # fix prompt_inputs["input_ids"] length issue
+ if self.max_prompt_length is not None:
+ prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
+ prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
+
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+
+ if self.max_prompt_length is not None:
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+ if self.temporal and video_inputs:
+ indices = torch.randperm(video_inputs[0].size(0))
+ shuffled_video_inputs = [video_inputs[0][indices]]
+ shuffled_prompt_inputs = self.processing_class(
+ text=copy.deepcopy(prompts_text),
+ images=image_inputs,
+ videos=shuffled_video_inputs,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ add_special_tokens=False,
+ )
+ shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
+ shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
+ if self.max_prompt_length is not None:
+ shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
+ shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
+
+
+ # Generate completions
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+ prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
+ prompt_length = prompt_ids.size(1)
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
+ completion_ids = prompt_completion_ids[:, prompt_length:]
+ prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
+
+ if self.temporal:
+
+ if video_inputs:
+
+ shuffled_prompt_completion_ids = unwrapped_model.generate(**shuffled_prompt_inputs, generation_config=self.shuffled_generation_config)
+ shuffled_prompt_length = shuffled_prompt_ids.size(1)
+ shuffled_prompt_ids = shuffled_prompt_completion_ids[:, :shuffled_prompt_length]
+ shuffled_completion_ids = shuffled_prompt_completion_ids[:, shuffled_prompt_length:]
+ shuffled_prompt_mask = prompt_mask.repeat_interleave(self.shuffled_num_generations, dim=0)
+
+ else:
+
+ shuffled_prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.dummy_generation_config)
+
+
+ # print('path:', input_copy[0]['content'][0][inputs[0]['data_type']])
+ # print('problem_id:', inputs[0]['problem_id'])
+ # print('prompt_length:', prompt_length)
+
+
+
+
+ # Mask everything after the first EOS token
+ is_eos = completion_ids == self.processing_class.eos_token_id
+ device = self.accelerator.device
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+ # Concatenate prompt_mask with completion_mask for logit computation
+ # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
+ # pixel_values = prompt_inputs["pixel_values"].repeat(self.num_generations, 1)
+ # image_grid_thw = prompt_inputs["image_grid_thw"].repeat_interleave(self.num_generations, dim=0)
+
+
+
+ prompt_inputs.pop("input_ids")
+ prompt_inputs.pop("attention_mask")
+
+ if inputs[0]['data_type'] == 'image':
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ # import pdb; pdb.set_trace()
+
+
+ if inputs[0]['data_type'] == 'video':
+ prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ if 'second_per_grid_ts' in prompt_inputs:
+ del prompt_inputs["second_per_grid_ts"]
+ # prompt_inputs["second_per_grid_ts"] = torch.tensor(prompt_inputs["second_per_grid_ts"]).repeat(len(prompt_completion_ids), 1)
+
+
+
+
+ try:
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ per_token_logps = per_token_logps[:, prompt_length - 1 :]
+ except Exception as e:
+ print(f"Error computing per_token_logps: {e}. Setting output to zero.")
+ # per_token_logps = torch.tensor(0.0, device=prompt_completion_ids.device, requires_grad=True)
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids)
+
+ with torch.inference_mode():
+ try:
+ if self.ref_model is not None:
+ ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
+ else:
+ with self.accelerator.unwrap_model(model).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+ except Exception as e:
+ print(f"Error computing ref_per_token_logps: {e}. Setting output to zero.")
+ # ref_per_token_logps = torch.tensor(0.0, device=prompt_completion_ids.device)
+ with self.accelerator.unwrap_model(model).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids)
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+
+ # Compute the KL divergence between the model and the reference model
+
+ x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10) # 限制 x 的范围
+ per_token_kl = torch.exp(x_clamped) - x_clamped - 1
+
+ if self.temporal and video_inputs:
+ shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
+ if is_conversational(inputs[0]):
+ shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
+
+ # Compute the rewards
+ shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
+ shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+ shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+ for key in shuffled_reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column for `num_generations` times
+ shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
+ shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
+ shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
+
+
+ # Decode the generated completions
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+ if is_conversational(inputs[0]):
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
+
+ # Compute the rewards
+ prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+ reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+ for key in reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column for `num_generations` times
+ reward_kwargs[key].extend([example[key]] * self.num_generations)
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+
+
+
+ if self.temporal and video_inputs:
+ temporal_rewards_per_func = rewards_per_func.clone()
+
+ acc_mean = temporal_rewards_per_func[:, 0].mean()
+ shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
+
+ if acc_mean >= 0.8 * shuffled_acc_mean:
+ mask = temporal_rewards_per_func[:, 0] > 0.1
+ temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
+ temporal_rewards = torch.tensor([1.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.5]).to('cuda')
+
+ # Sum the rewards from all reward functions
+ if self.temporal and video_inputs:
+ rewards = temporal_rewards_per_func.sum(dim=1)
+ else:
+ rewards = rewards_per_func.sum(dim=1)
+
+
+ if self.len_control:
+ mem_rewards = [0] * self.num_generations
+ mask = rewards_per_func[:, 0] > 0.1
+ lenth_list = completion_mask.sum(1)
+ selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
+ # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
+ # if len(selected_indices) > 1:
+ # selected_items = [(i, lenth_list[i]) for i in selected_indices]
+ # sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
+ # N = len(sorted_items)
+ # for rank, (idx, length) in enumerate(sorted_items):
+ # reward = 0.2 - 0.2 * (rank / N)
+ # rewards[idx] += reward
+ # mem_rewards[idx] = reward
+ # for idx in range(len(lenth_list)):
+ # if lenth_list[idx] >= 512:
+ # rewards[idx] -= 0.5
+
+ if len(selected_indices) > 1:
+ for idx in selected_indices:
+ if 320 <= lenth_list[idx] <= 512:
+ rewards[idx] += 0.2
+
+ # print(rewards)
+ # print(completion_mask.sum(1))
+
+ # Compute grouped-wise rewards
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+ # Normalize the rewards to compute the advantages
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+ # if self.len_control and len(selected_indices) == self.num_generations:
+ # for idx in range(len(rewards)):
+ # advantages[idx] += (mem_rewards[idx] - 0.2) * 2
+
+ # x - x.detach() allows for preserving gradients from x
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+ # per_token_loss = -per_token_loss
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+
+ # import pdb
+ # pdb.set_trace()
+
+ # Log the metrics
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+ self._metrics["completion_length"].append(completion_length)
+
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+ else:
+ reward_func_name = reward_func.__name__
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+ gathered_rewards = self.accelerator.gather_for_metrics(rewards)
+
+ num_devices = gathered_rewards.size(0) // self.num_generations
+ rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
+ wrong_devices = (rewards_per_device <= 1).all(dim=1)
+ wrong_ratio = wrong_devices.sum().item() / num_devices
+
+ correct_devices = (rewards_per_device >= 2).all(dim=1)
+ correct_ratio = correct_devices.sum().item() / num_devices
+
+ self._metrics["all_wrong"].append(wrong_ratio)
+ self._metrics["all_correct"].append(correct_ratio)
+
+ if self.temporal:
+ temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
+ self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
+
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+
+ return loss
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
+ logs = {**logs, **metrics}
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ super().log(logs, start_time)
+ else: # transformers<=4.46
+ super().log(logs)
+ self._metrics.clear()
+
+ def create_model_card(
+ self,
+ model_name: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ tags: Union[str, list[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ model_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the model.
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the dataset used for training.
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+ Tags to be associated with the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+ base_model = self.model.config._name_or_path
+ else:
+ base_model = None
+
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
+ citation = textwrap.dedent(
+ """\
+ @article{zhihong2024deepseekmath,
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
+ year = 2024,
+ eprint = {arXiv:2402.03300},
+ """
+ )
+
+ model_card = generate_model_card(
+ base_model=base_model,
+ model_name=model_name,
+ hub_model_id=self.hub_model_id,
+ dataset_name=dataset_name,
+ tags=tags,
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
+ trainer_name="GRPO",
+ trainer_citation=citation,
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
+ paper_id="2402.03300",
+ )
+
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
diff --git a/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified.py b/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d071f8e25b17f871ebc6e8e4c423592cd8c990
--- /dev/null
+++ b/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified.py
@@ -0,0 +1,1224 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import textwrap
+from collections import defaultdict
+from typing import Any, Callable, Optional, Union
+from accelerate.utils.other import is_compiled_module
+from accelerate.utils import broadcast_object_list, gather, gather_object
+import torch
+import torch.utils.data
+import transformers
+import warnings
+from unittest.mock import patch
+from datasets import Dataset, IterableDataset
+from packaging import version
+from transformers import (
+ AriaForConditionalGeneration,
+ AriaProcessor,
+ AutoModelForCausalLM,
+ AutoModelForSequenceClassification,
+ AutoProcessor,
+ AutoTokenizer,
+ GenerationConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Qwen2VLForConditionalGeneration,
+ Qwen2_5_VLForConditionalGeneration,
+ Trainer,
+ TrainerCallback,
+ is_wandb_available,
+)
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import is_peft_available
+
+from trl.data_utils import (
+ apply_chat_template,
+ is_conversational,
+ maybe_apply_chat_template,
+)
+from trl.import_utils import is_vllm_available
+
+from trl.models import (
+ create_reference_model,
+ prepare_deepspeed,
+ unwrap_model_for_generation,
+)
+from trl.trainer.grpo_config import GRPOConfig
+from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
+from trl import GRPOTrainer
+
+import copy
+
+if is_peft_available():
+ from peft import PeftConfig, get_peft_model
+
+if is_vllm_available():
+ from vllm import LLM, SamplingParams
+
+if is_wandb_available():
+ import wandb
+import torch.nn as nn
+from torch.utils.data import Sampler
+import gc
+from qwen_vl_utils import process_vision_info
+
+
+
+# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+import re
+
+def extract_answer(text: str) -> str:
+ """
+ 1) Try the full … block.
+ 2) If that is missing, grab whatever follows the opening tag.
+ 3) Otherwise return the original text.
+ """
+ # ① normal case …
+ m = re.search(r'\s*(.*?)\s*', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ② fallback …
+ m = re.search(r'\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ③ nothing found
+ return text.strip()
+
+def extract_info(predict: str) -> Optional[str]:
+ """
+ Extracts the content of the … block from `predict`.
+ Returns the inner text (with leading/trailing whitespace stripped),
+ or None if no tag is found.
+ """
+ match = re.search(r"([\s\S]*?)", predict, re.DOTALL)
+ if not match:
+ return predict
+ return match.group(1).strip()
+
+
+
+
+class Qwen2VLGRPOVLLMTrainerModified(Trainer):
+ def __init__(
+ self,
+ model: Union[str, PreTrainedModel],
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
+ args: GRPOConfig = None,
+ script_args = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+ eval_dataset: Optional[
+ Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
+ ] = None,
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
+ reward_processing_classes: Optional[
+ Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
+ ] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[
+ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
+ ] = (None, None),
+ peft_config: Optional["PeftConfig"] = None,
+ # qwen2-vl related params
+ max_pixels: Optional[int] = 12845056,
+ min_pixels: Optional[int] = 3136,
+ attn_implementation: str = "flash_attention_2",
+ ):
+
+ # Args
+ if args is None:
+ model_name = model if isinstance(model, str) else model.config._name_or_path
+ model_name = model_name.split("/")[-1]
+ args = GRPOConfig(f"{model_name}-GRPO")
+
+ # Models
+ # Trained model
+ model_init_kwargs = args.model_init_kwargs or {}
+ model_init_kwargs["attn_implementation"] = attn_implementation
+ if isinstance(model, str):
+ model_id = model
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if (
+ isinstance(torch_dtype, torch.dtype)
+ or torch_dtype == "auto"
+ or torch_dtype is None
+ ):
+ pass # torch_dtype is already a torch.dtype or "auto" or None
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
+ torch_dtype = getattr(torch, torch_dtype)
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ else:
+ raise ValueError(
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+ )
+ # Disable caching if gradient checkpointing is enabled (not supported)
+ model_init_kwargs["use_cache"] = (
+ False
+ if args.gradient_checkpointing
+ else model_init_kwargs.get("use_cache")
+ )
+ if "Qwen2-VL" in model_id:
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ elif "Qwen2.5-VL" in model_id:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ elif "Aria" in model_id:
+ model_init_kwargs.pop("use_cache")
+ model = AriaForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ else:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ else:
+ model_id = model.config._name_or_path
+ if args.model_init_kwargs is not None:
+ raise ValueError(
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+ "This argument can only be used when the `model` argument is a string."
+ )
+
+ if peft_config is not None:
+ model = get_peft_model(model, peft_config)
+
+ # Reference model
+ if is_deepspeed_zero3_enabled():
+ if "Qwen2-VL" in model_id:
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif "Qwen2.5-VL" in model_id:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif "Aria" in model_id:
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ else:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif peft_config is None:
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
+ self.ref_model = create_reference_model(model)
+ else:
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
+ # to revert to the initial model.
+ self.ref_model = None
+
+ # Processing class
+ # if processing_class is None:
+ # if "Qwen" in model_id or "Aria" in model_id:
+ # processing_class = AutoProcessor.from_pretrained(model_id)
+ # pad_token_id = processing_class.tokenizer.pad_token_id
+ # processing_class.pad_token_id = pad_token_id
+ # processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+ # if "Qwen" in model_id:
+ # processing_class.image_processor.max_pixels = max_pixels
+ # processing_class.image_processor.min_pixels = min_pixels
+ # else:
+ # processing_class = AutoTokenizer.from_pretrained(
+ # model.config._name_or_path, padding_side="left"
+ # )
+ # pad_token_id = processing_class.pad_token_id
+
+ if processing_class is None:
+ # 1️⃣ First try to load whatever lives in the directory we were given.
+ # This succeeds if you previously did `processor.save_pretrained(output_dir)`.
+ try:
+ processing_class = AutoProcessor.from_pretrained(model_id)
+ pad_token_id = processing_class.tokenizer.pad_token_id
+ except (OSError, ValueError): # no processor files found
+ # 2️⃣ Fall back to inspecting the *model object* instead of the path.
+ is_vl_model = (
+ hasattr(model, "vision_tower") or # Qwen-VL, InternVL, etc.
+ getattr(model.config, "vision_config", None) is not None or
+ getattr(model.config, "image_vocab_size", None) is not None
+ )
+
+ if is_vl_model:
+ # Always use the *base* model name stored in the config.
+ base_name = model.config._name_or_path # e.g. "Qwen/Qwen2.5-VL-7B-Instruct"
+ processing_class = AutoProcessor.from_pretrained(base_name)
+ pad_token_id = processing_class.tokenizer.pad_token_id
+
+ # Optional Qwen-specific limits
+ if hasattr(processing_class, "image_processor"):
+ processing_class.image_processor.max_pixels = max_pixels
+ processing_class.image_processor.min_pixels = min_pixels
+ else:
+ # Pure text model → plain tokenizer
+ processing_class = AutoTokenizer.from_pretrained(
+ model.config._name_or_path, padding_side="left"
+ )
+ pad_token_id = processing_class.pad_token_id
+
+ # 3️⃣ Harmonise attributes the rest of the trainer expects
+ processing_class.pad_token_id = pad_token_id
+ if not hasattr(processing_class, "eos_token_id"):
+ processing_class.eos_token_id = pad_token_id
+
+ # Reward functions
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+ for i, reward_func in enumerate(reward_funcs):
+ if isinstance(reward_func, str):
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+ reward_func, num_labels=1, **model_init_kwargs
+ )
+ self.reward_funcs = reward_funcs
+
+ # Reward processing class
+ if reward_processing_classes is None:
+ reward_processing_classes = [None] * len(reward_funcs)
+ elif not isinstance(reward_processing_classes, list):
+ reward_processing_classes = [reward_processing_classes]
+ else:
+ if len(reward_processing_classes) != len(reward_funcs):
+ raise ValueError(
+ "The number of reward processing classes must match the number of reward functions."
+ )
+
+ for i, (reward_processing_class, reward_func) in enumerate(
+ zip(reward_processing_classes, reward_funcs)
+ ):
+ if isinstance(reward_func, PreTrainedModel):
+ if reward_processing_class is None:
+ reward_processing_class = AutoTokenizer.from_pretrained(
+ reward_func.config._name_or_path
+ )
+ if reward_processing_class.pad_token_id is None:
+ reward_processing_class.pad_token = (
+ reward_processing_class.eos_token
+ )
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+ reward_processing_classes[i] = reward_processing_class
+ self.reward_processing_classes = reward_processing_classes
+
+ # Data collator
+ def data_collator(features): # No data collation is needed in GRPO
+ return features
+
+ # Training arguments
+ self.max_prompt_length = args.max_prompt_length
+ self.max_completion_length = (
+ args.max_completion_length
+ ) # = |o_i| in the GRPO paper
+ self.num_generations = args.num_generations # = G in the GRPO paper
+ self.temporal = script_args.temporal
+ self.generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ temperature=1, # HACK
+ num_return_sequences=self.num_generations,
+ pad_token_id=pad_token_id,
+ )
+ self.beta = args.beta
+
+ self.shuffled_num_generations = self.num_generations // 2
+ self.shuffled_generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1, # HACK
+ num_return_sequences=self.shuffled_num_generations,
+ pad_token_id=pad_token_id,
+ )
+
+ self.dummy_generation_config = GenerationConfig(
+ max_new_tokens=1,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1, # HACK
+ num_return_sequences=1,
+ pad_token_id=pad_token_id,
+ )
+ self.len_control = script_args.len_control
+ self.beta = args.beta
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+ # This acts as a flag to indicate that the warning has already been issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ # Initialize the metrics
+ self._metrics = defaultdict(list)
+ self.use_vllm = args.use_vllm
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ )
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ if self.use_vllm:
+ if not is_vllm_available():
+ raise ImportError(
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
+ "`pip install vllm` to use it."
+ )
+
+ if self.accelerator.is_main_process:
+ vllm_device = self.args.vllm_device
+ if vllm_device == "auto":
+ vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
+
+ # ──────────────────── NEW BEGIN ────────────────────────
+ # Accept a comma-separated list, e.g. "cuda:6,7"
+ # device_tokens = [tok.strip() for tok in vllm_device.split(",")]
+ # multi_gpu = len(device_tokens) > 1
+
+ # if multi_gpu:
+ # # keep only the numeric part ("cuda:6" -> "6")
+ # # physical_ids = [tok.split(":")[1] for tok in device_tokens]
+ # physical_ids = [tok.split(":")[-1] for tok in device_tokens]
+
+ # # Mask visibility *in this process only* (rank-0)
+ # os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(physical_ids)
+
+ # logical_device = "cuda" # vLLM sees them as 0,1,…
+ # tensor_parallel_size = len(physical_ids)
+ # else:
+ # logical_device = vllm_device # single id like "cuda:6"
+ # tensor_parallel_size = 1
+
+ # vllm_device = logical_device
+ # ──────────────────── NEW END ────────────────────────
+
+
+ # Check that the requested device is available
+ '''
+ The first if statement below is to guard vllm errors'''
+ # if (not multi_gpu) and vllm_device.startswith("cuda:"):
+ # gpu_idx = int(vllm_device.split(":")[1])
+ # if gpu_idx >= torch.cuda.device_count():
+ # raise ValueError(
+ # f"The requested device {vllm_device} is not available. "
+ # f"You only have {torch.cuda.device_count()} GPUs."
+ # )
+
+ # # ---------- overlap-with-training warning (skip for multi-GPU) ---------
+ # if (not multi_gpu) and vllm_device in {
+ # f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+ # }:
+ # warnings.warn(
+ # f"The requested vLLM device {vllm_device} is also used for training. "
+ # "This may lead to unexpected behaviour."
+ # )
+ if (
+ vllm_device.split(":")[0] == "cuda"
+ and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
+ ):
+ raise ValueError(
+ f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
+ "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
+ "value lower than the number of GPUs available on your machine—typically, reducing it by one "
+ f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+ )
+ # Check that the requested device is not also used for training
+ if vllm_device in {
+ f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+ }:
+ warnings.warn(
+ f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
+ "behavior. It is recommended to use a dedicated device for vLLM."
+ )
+ # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
+ # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
+ # setting (profiling_patch).
+ # world_size_patch = patch(
+ # "torch.distributed.get_world_size", return_value=1
+ # )
+
+ '''
+ Below is the cahnged code
+ '''
+ # world_size_patch = patch(
+ # "torch.distributed.get_world_size", return_value=tensor_parallel_size
+ # )
+
+ # profiling_patch = patch(
+ # "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+ # return_value=None,
+ # )
+ '''Above is the changed code'''
+
+ world_size_patch = patch(
+ "torch.distributed.get_world_size", return_value=1
+ )
+ profiling_patch = patch(
+ "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+ return_value=None,
+ )
+
+ '''
+ Below changes
+ '''
+ with world_size_patch, profiling_patch:
+ # with profiling_patch:
+ print("vllm is running on: ", vllm_device)
+ from vllm.config import ParallelConfig
+ self.llm = LLM(
+ model=model.name_or_path,
+ device=vllm_device,
+ # tensor_parallel_size=tensor_parallel_size, # ← 1 or N
+ # parallel_config=ParallelConfig( # ← NEW
+ # tensor_parallel_size=tensor_parallel_size
+ # ),
+ gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
+ dtype=torch.bfloat16,
+ # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
+ # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
+ # This is particularly useful here because we generate completions from the same prompts.
+ enable_prefix_caching=True,
+ enforce_eager=True,
+ mm_processor_kwargs=(
+ {
+ "max_pixels": max_pixels,
+ "min_pixels": min_pixels,
+ }
+ # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
+ if False
+ else None
+ ),
+ max_model_len=args.max_prompt_length + args.max_completion_length,
+ )
+ self.sampling_params = SamplingParams(
+ temperature=1.0,
+ top_p=0.95,
+ max_tokens=self.max_completion_length,
+ )
+
+ # self.second_sampling_params = SamplingParams(
+ # n = 1, # one generation
+ # temperature = 0.5, # less squeezing
+ # top_p = 0.9, # nucleus filter
+ # # top_k = 50, # (alternative to top_p)
+ # min_tokens = 4, # force at least 4 tokens
+ # max_tokens = self.max_completion_length,
+ # )
+ self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
+
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+ # synchronize all processes after vLLM has been fully initialized.
+ self.accelerator.wait_for_everyone()
+ else:
+ raise ValueError(
+ "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
+ )
+
+ if self.ref_model is not None:
+ if self.is_deepspeed_enabled:
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+ def _set_signature_columns_if_needed(self):
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+ if self._signature_columns is None:
+ self._signature_columns = ["prompt"]
+
+ # Get the per-token log probabilities for the completions for the model and the reference model
+ def _get_per_token_logps(self, model, input_ids, **kwargs):
+ # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
+ # import pdb
+ # pdb.set_trace()
+ logits = model(input_ids, **kwargs).logits
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+ per_token_logps = []
+ for logits_row, input_ids_row in zip(logits, input_ids):
+ log_probs = logits_row.log_softmax(dim=-1)
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+ per_token_logps.append(token_log_prob)
+ return torch.stack(per_token_logps)
+
+ # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
+ # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
+ def _prepare_inputs(
+ self, inputs: dict[str, Union[torch.Tensor, Any]]
+ ) -> dict[str, Union[torch.Tensor, Any]]:
+ return inputs
+
+ def remove_none_from_data(self, data):
+ for entry in data:
+ if "content" in entry and isinstance(entry["content"], list):
+ for sub_entry in entry["content"]:
+ if isinstance(sub_entry, dict):
+ keys_to_remove = [k for k, v in sub_entry.items() if v is None]
+ for k in keys_to_remove:
+ del sub_entry[k]
+ return data
+
+
+
+ def compute_loss(
+ self, model, inputs, return_outputs=False, num_items_in_batch=None
+ ):
+ if return_outputs:
+ raise ValueError("The GRPOTrainer does not support returning outputs")
+ # Compute the per-token log probabilities for the model
+
+
+ device = self.accelerator.device
+ prompts = [x["prompt"] for x in inputs]
+ # images = [x["image"] for x in inputs]
+ prompts_text = [
+ maybe_apply_chat_template(example, self.processing_class)["prompt"]
+ for example in inputs
+ ]
+
+ input_copy = copy.deepcopy(inputs[0]['prompt'])
+
+ input_copy = self.remove_none_from_data(input_copy)
+
+ data_type = inputs[0]['data_type']
+
+ if data_type == 'image':
+ input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+ elif data_type == 'video':
+ input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+
+
+ image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+
+
+ prompt_inputs = self.processing_class(
+ text=copy.deepcopy(prompts_text),
+ images=image_inputs,
+ videos=video_inputs,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ add_special_tokens=False,
+ )
+
+ mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+ if self.max_prompt_length is not None:
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+
+ if self.temporal:
+ if video_inputs:
+ indices = torch.randperm(video_inputs[0].size(0))
+ shuffled_video_inputs = [video_inputs[0][indices]]
+ shuffled_prompt_inputs = self.processing_class(
+ text=copy.deepcopy(prompts_text),
+ images=image_inputs,
+ videos=shuffled_video_inputs,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ add_special_tokens=False,
+ )
+ shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
+ shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
+ shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
+ if self.max_prompt_length is not None:
+ shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
+ shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
+ else:
+ shuffled_mm_data = [None]
+
+
+
+ if self.args.use_vllm:
+ # First, have main process load weights if needed
+ if self.state.global_step != self._last_loaded_step:
+ with unwrap_model_for_generation(
+ self.model,
+ self.accelerator,
+ gather_deepspeed3_params=True, # TODO: fix this, self.args.ds3_gather_for_generation,
+ ) as unwrapped_model:
+ if is_compiled_module(unwrapped_model):
+ state_dict = unwrapped_model._orig_mod.state_dict()
+ else:
+ state_dict = unwrapped_model.state_dict()
+ if self.accelerator.is_main_process:
+ llm_model = (
+ self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+ )
+ # import pdb
+ # pdb.set_trace()
+ llm_model.load_weights(state_dict.items())
+ self._last_loaded_step = self.state.global_step
+
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+ all_prompts_text = gather_object(prompts_text)
+ all_mm_data = gather_object(mm_data)
+ # group into pairs
+ all_multimodal_inputs = []
+
+ if self.temporal:
+ shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
+ shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
+ shuffled_all_multimodal_inputs = []
+
+ # 2. Refer to TobiasLee's implementation suggestions
+ # this is a better implementation for vLLM sampling.
+ for prompt, mm_item in zip(all_prompts_text, all_mm_data):
+ all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ for mm_item in shuffled_all_mm_data:
+ shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})
+
+ # Create sampling params with num_generations
+ if self.accelerator.is_main_process:
+ # Clone to avoid modifying original params
+ sampling_params = copy.deepcopy(self.sampling_params)
+ sampling_params.n = self.num_generations
+ # Single generate call with all prompts
+ if self.accelerator.is_main_process:
+ outputs = self.llm.generate(
+ all_multimodal_inputs,
+ sampling_params=sampling_params,
+ use_tqdm=False,
+ )
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
+ completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ # Clone to avoid modifying original params
+ shuffled_sampling_params = copy.deepcopy(self.sampling_params)
+ shuffled_sampling_params.n = self.num_generations // 2
+ # Single generate call with all prompts
+ if self.accelerator.is_main_process:
+ shuffled_outputs = self.llm.generate(
+ shuffled_all_multimodal_inputs,
+ sampling_params=shuffled_sampling_params,
+ use_tqdm=False,
+ )
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
+ shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]
+
+
+ else:
+ completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)
+
+
+ # broadcast and slice
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
+ process_slice = slice(
+ self.accelerator.process_index * len(prompts) * self.num_generations,
+ (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
+ )
+ completion_ids = completion_ids[process_slice]
+
+ # Pad the completions, and concatenate them with the prompts
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+ completion_ids = pad(
+ completion_ids, padding_value=self.processing_class.pad_token_id
+ )
+ prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+
+ prompt_length = prompt_ids.size(1)
+
+ # print('prompt_length:', prompt_length)
+
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
+ completion_ids = prompt_completion_ids[:, prompt_length:]
+ prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
+
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ # broadcast and slice
+ shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
+ process_id_list = []
+ for mm_item in shuffled_all_mm_data:
+ process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
+
+ if video_inputs:
+ cur_shuffled_completion_ids = []
+ for i in range(len(process_id_list)):
+ if self.accelerator.process_index == process_id_list[i]:
+ cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
+
+ # Pad the completions, and concatenate them with the prompts
+ cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
+ cur_shuffled_completion_ids = pad(
+ cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
+ )
+ shuffled_completion_ids = cur_shuffled_completion_ids
+
+
+ else:
+ raise ValueError("Only vLLM generation is supported in this version ")
+
+ # below are the same with yifan's code
+ # Mask everything after the first EOS token
+ is_eos = completion_ids == self.processing_class.eos_token_id
+ device = self.accelerator.device
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+
+
+ prompt_inputs.pop("input_ids")
+ prompt_inputs.pop("attention_mask")
+
+ if data_type == 'image':
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ # import pdb; pdb.set_trace()
+
+
+ if data_type == 'video':
+ prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ if 'second_per_grid_ts' in prompt_inputs:
+ del prompt_inputs["second_per_grid_ts"]
+
+ # import pdb
+ # pdb.set_trace()
+
+ # per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+ per_token_logps = per_token_logps[:, prompt_length - 1 :]
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ with torch.inference_mode():
+ if self.ref_model is not None:
+ ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
+ else:
+ with self.accelerator.unwrap_model(model).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+
+ x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10) # 限制 x 的范围
+ per_token_kl = torch.exp(x_clamped) - x_clamped - 1
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if self.temporal and video_inputs:
+
+ shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
+ if is_conversational(inputs[0]):
+ shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
+
+ # Compute the rewards
+ shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
+ shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+ shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+ for key in shuffled_reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column for `num_generations` times
+ shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
+ shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
+ shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
+
+
+
+ # Decode the generated completions
+ completions = self.processing_class.batch_decode(
+ completion_ids, skip_special_tokens=True
+ )
+
+ '''
+ Below code is added for second round generation
+ '''
+ # second_stage_prompts_text = completions # ← list[str]
+ # curr_problem = example['problem']
+ # print('curr problem is: ', curr_problem)
+ # problem_key = "problem" if "problem" in inputs[0] else "question"
+ # # ─── For each sample in the batch, repeat its problem text num_generations times
+ # problems_aligned = [
+ # str(ex[problem_key])
+ # for ex in inputs
+ # for _ in range(self.num_generations)
+ # ]
+
+ # 1️⃣ descriptions extracted from first-round completions
+ second_stage_prompts_descriptions = [
+ str(extract_info(c) or "") # len = B * n_gen
+ for c in completions
+ ]
+
+ # 2️⃣ obtain + template the verify prompt for every sample,
+ # then repeat it n_gen times to align with descriptions
+ verify_templates = []
+ for ex in inputs: # B samples
+ tmpl = ex["verify_prompt"] # may be dict or str
+
+ # ▸ if it's still a dict, wrap it NOW
+ if not isinstance(tmpl, str):
+ tmpl = maybe_apply_chat_template(
+ tmpl, # conversation-dict
+ self.processing_class
+ )["prompt"] # templated string
+
+ verify_templates.extend([tmpl] * self.num_generations)
+
+ # 3️⃣ fill the {description} or {Description} slot
+ def fill_template(tmpl: str, desc: str) -> str:
+ # Replace both spelling variants and avoid all other {…} in the string
+ return (tmpl
+ .replace("{Description}", desc)
+ .replace("{description}", desc))
+
+ second_stage_chat_prompts = [
+ fill_template(tmpl, desc)
+ for tmpl, desc in zip(verify_templates, second_stage_prompts_descriptions)
+ ]
+
+ # 4️⃣ ready for vLLM – already chat-templated
+ all_second_prompts_text = gather_object(second_stage_chat_prompts)
+ second_multimodal_inputs = [
+ {"prompt": p, "multi_modal_data": {}} # text-only; no vision inputs
+ for p in all_second_prompts_text
+ ]
+ second_stage_prompts_text = second_stage_chat_prompts
+
+ # # print("problems_aligned types:", [type(p).__name__ for p in problems_aligned])
+ # # print("second_stage_prompts_descriptions types:", [type(s).__name__ for s in second_stage_prompts_descriptions])
+ # # second_stage_prompts_text = [ABS_Verify_Prompt.replace('{text}', second_stage_prompts_descriptions[count_index]).replace('{question}', problems_aligned[count_index]) for count_index in range(len(second_stage_prompts_descriptions))]
+ # second_stage_prompts_text = [ABS_Verify_Prompt.format(second_stage_prompts_descriptions[count_index], problems_aligned[count_index].replace('', '')) for count_index in range(len(second_stage_prompts_descriptions))]
+ # # print('Problems aligned: ', problems_aligned)
+ # # print('-'*10)
+ # # import time
+ # # time.sleep(40)
+
+ # second_stage_chat_prompts = [
+ # maybe_apply_chat_template( # ← your helper
+ # {
+ # "prompt": [
+ # {
+ # "role": "user",
+ # "content": [
+ # {"type": "text", "text": p} # ONLY text this round
+ # ],
+ # },
+ # ],
+ # },
+ # self.processing_class,
+ # )["prompt"] # returns the templated string
+ # for p in second_stage_prompts_text
+ # ]
+
+
+
+ # # 2️⃣ Tokenise / pad just like before (no image- or video-data)
+ # second_stage_inputs = self.processing_class(
+ # text=second_stage_prompts_text,
+ # images=None,
+ # videos=None,
+ # return_tensors="pt",
+ # padding=True,
+ # padding_side="left",
+ # add_special_tokens=False,
+ # )
+ # second_stage_inputs = super()._prepare_inputs(second_stage_inputs)
+
+ # # 3️⃣ Build the vLLM input objects (empty multi-modal dict)
+ # # all_second_prompts_text = gather_object(second_stage_prompts_text)
+ # all_second_prompts_text = gather_object(second_stage_chat_prompts)
+ # second_multimodal_inputs = [
+ # {"prompt": p, "multi_modal_data": {}} # no vision inputs this round
+ # for p in all_second_prompts_text
+ # ]
+
+ # print('Second stage prompt input: ')
+ # print(second_multimodal_inputs[0])
+ # print('*'*10)
+ # import time
+ # print('Examining output')
+ # time.sleep(10)
+
+ # 4️⃣ vLLM generation (same sampling params, same number of gens)
+ if self.accelerator.is_main_process:
+ second_sampling_params = copy.deepcopy(self.sampling_params)
+ # second_sampling_params = copy.deepcopy(self.second_sampling_params)
+ second_sampling_params.n = self.num_generations
+ second_outputs = self.llm.generate(
+ second_multimodal_inputs,
+ sampling_params=second_sampling_params,
+ use_tqdm=False,
+ )
+ second_completion_ids = [
+ out.token_ids
+ for completion in second_outputs
+ for out in completion.outputs
+ ]
+ else:
+ second_completion_ids = [None] * len(second_multimodal_inputs) * self.num_generations
+
+ # 5️⃣ Broadcast / slice back to every process
+ second_completion_ids = broadcast_object_list(second_completion_ids, from_process=0)
+ process_slice2 = slice(
+ self.accelerator.process_index * len(second_stage_prompts_text) * self.num_generations,
+ (self.accelerator.process_index + 1) * len(second_stage_prompts_text) * self.num_generations,
+ )
+ second_completion_ids = second_completion_ids[process_slice2]
+
+ # 6️⃣ Pad & move to device
+ second_completion_ids = [
+ torch.tensor(ids, device=device) for ids in second_completion_ids
+ ]
+ second_completion_ids = pad(
+ second_completion_ids, padding_value=self.processing_class.pad_token_id
+ )
+
+ # 7️⃣ Decode the second-round generations (list[str])
+ second_completions = self.processing_class.batch_decode(
+ second_completion_ids, skip_special_tokens=True
+ )
+
+ # print('Second completions: ')
+ # print(second_completions[0])
+ # print('*'*10)
+ # time.sleep(40)
+
+ # 8️⃣ (Optional) wrap conversationally, log, or feed into further
+ # reward computation just like the first-round completions.
+ # For example:
+ # if is_conversational(inputs[0]):
+ # second_completions = [
+ # [{"role": "assistant", "content": c}] for c in second_completions
+ # ]
+
+ second_round_info = {
+ "second_prompts": second_stage_prompts_text, # list[str]
+ "second_completions": second_completions, # list[str]
+ }
+ '''
+ Above code is added for second round generation
+ '''
+
+
+
+
+ if is_conversational(inputs[0]):
+ completions = [
+ [{"role": "assistant", "content": completion}]
+ for completion in completions
+ ]
+
+ # Compute the rewards
+ prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+ rewards_per_func = torch.zeros(
+ len(prompts), len(self.reward_funcs), device=device
+ )
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ reward_kwargs = {
+ key: []
+ for key in inputs[0].keys()
+ if key not in ["prompt", "completion"]
+ }
+
+ # reward_kwargs.update(second_round_info)
+
+ for key in reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column for `num_generations` times
+ reward_kwargs[key].extend([example[key]] * self.num_generations)
+
+ reward_kwargs["second_prompts"] = second_stage_prompts_text # len = len(completions)
+ reward_kwargs["second_completions"] = second_completions
+
+ output_reward_func = reward_func(
+ prompts=prompts, completions=completions, **reward_kwargs
+ )
+ rewards_per_func[:, i] = torch.tensor(
+ output_reward_func, dtype=torch.float32, device=device
+ )
+
+
+ # rewards_per_func = gather(rewards_per_func)
+ # # Sum the rewards from all reward functions
+ # rewards = rewards_per_func.sum(dim=1)
+
+ # process_slice = slice(
+ # self.accelerator.process_index * len(prompts),
+ # (self.accelerator.process_index + 1) * len(prompts),
+ # )
+
+ # rewards = rewards[process_slice]
+
+
+
+ if self.temporal and video_inputs:
+ temporal_rewards_per_func = rewards_per_func.clone()
+
+ acc_mean = temporal_rewards_per_func[:, 0].mean()
+ shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
+
+ if acc_mean >= 0.8 * shuffled_acc_mean:
+ mask = temporal_rewards_per_func[:, 0] > 0.1
+ temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
+ temporal_rewards = torch.tensor([1.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.5]).to('cuda')
+
+ # Sum the rewards from all reward functions
+ if self.temporal and video_inputs:
+ rewards = temporal_rewards_per_func.sum(dim=1)
+ else:
+ rewards = rewards_per_func.sum(dim=1)
+
+ if self.len_control:
+ mem_rewards = [0] * self.num_generations
+ mask = rewards_per_func[:, 0] > 0.1
+ lenth_list = completion_mask.sum(1)
+ selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
+ # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
+ # if len(selected_indices) > 1:
+ # selected_items = [(i, lenth_list[i]) for i in selected_indices]
+ # sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
+ # N = len(sorted_items)
+ # for rank, (idx, length) in enumerate(sorted_items):
+ # reward = 0.2 - 0.2 * (rank / N)
+ # rewards[idx] += reward
+ # mem_rewards[idx] = reward
+ # for idx in range(len(lenth_list)):
+ # if lenth_list[idx] >= 512:
+ # rewards[idx] -= 0.5
+
+ if len(selected_indices) > 1:
+ for idx in selected_indices:
+ if 320 <= lenth_list[idx] <= 1600:
+ rewards[idx] += 0.2
+
+ # print(rewards)
+ # print(completion_mask.sum(1))
+
+ # Compute grouped-wise rewards
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+ # Normalize the rewards to compute the advantages
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+ # x - x.detach() allows for preserving gradients from x
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+ # per_token_loss = -per_token_loss
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+
+ # import pdb
+ # pdb.set_trace()
+
+ # Log the metrics
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+ self._metrics["completion_length"].append(completion_length)
+
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+ else:
+ reward_func_name = reward_func.__name__
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+ gathered_rewards = self.accelerator.gather_for_metrics(rewards)
+
+ num_devices = gathered_rewards.size(0) // self.num_generations
+ rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
+ wrong_devices = (rewards_per_device <= 1).all(dim=1)
+ wrong_ratio = wrong_devices.sum().item() / num_devices
+
+ correct_devices = (rewards_per_device >= 2).all(dim=1)
+ correct_ratio = correct_devices.sum().item() / num_devices
+
+ self._metrics["all_wrong"].append(wrong_ratio)
+ self._metrics["all_correct"].append(correct_ratio)
+
+ if self.temporal:
+ temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
+ self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
+
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+
+ return loss
+
+
+
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
+
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+ if next(iter(logs.keys())).startswith("eval_"):
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+ logs = {**logs, **metrics}
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ super().log(logs, start_time)
+ else: # transformers<=4.46
+ super().log(logs)
+ self._metrics.clear()
\ No newline at end of file
diff --git a/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_selfConst.py b/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_selfConst.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5ef5aa21378055878ca7125fd0c505fe303ee4
--- /dev/null
+++ b/src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_selfConst.py
@@ -0,0 +1,1186 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import textwrap
+from collections import defaultdict
+from typing import Any, Callable, Optional, Union
+from accelerate.utils.other import is_compiled_module
+from accelerate.utils import broadcast_object_list, gather, gather_object
+import torch
+import torch.utils.data
+import transformers
+import warnings
+from unittest.mock import patch
+from datasets import Dataset, IterableDataset
+from packaging import version
+from transformers import (
+ AriaForConditionalGeneration,
+ AriaProcessor,
+ AutoModelForCausalLM,
+ AutoModelForSequenceClassification,
+ AutoProcessor,
+ AutoTokenizer,
+ GenerationConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Qwen2VLForConditionalGeneration,
+ Qwen2_5_VLForConditionalGeneration,
+ Trainer,
+ TrainerCallback,
+ is_wandb_available,
+)
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import is_peft_available
+
+from trl.data_utils import (
+ apply_chat_template,
+ is_conversational,
+ maybe_apply_chat_template,
+)
+from trl.import_utils import is_vllm_available
+
+from trl.models import (
+ create_reference_model,
+ prepare_deepspeed,
+ unwrap_model_for_generation,
+)
+from trl.trainer.grpo_config import GRPOConfig
+from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
+from trl import GRPOTrainer
+
+import copy
+
+if is_peft_available():
+ from peft import PeftConfig, get_peft_model
+
+if is_vllm_available():
+ from vllm import LLM, SamplingParams
+
+if is_wandb_available():
+ import wandb
+import torch.nn as nn
+from torch.utils.data import Sampler
+import gc
+from qwen_vl_utils import process_vision_info
+
+import torch, deepspeed
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+
+# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+import re
+
+def extract_answer(text: str) -> str:
+ """
+ 1) Try the full … block.
+ 2) If that is missing, grab whatever follows the opening tag.
+ 3) Otherwise return the original text.
+ """
+ # ① normal case …
+ m = re.search(r'\s*(.*?)\s*', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ② fallback …
+ m = re.search(r'\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ③ nothing found
+ return text.strip()
+
+def extract_info(predict: str) -> Optional[str]:
+ """
+ Extracts the content of the … block from `predict`.
+ Returns the inner text (with leading/trailing whitespace stripped),
+ or None if no tag is found.
+ """
+ match = re.search(r"([\s\S]*?)", predict, re.DOTALL)
+ if not match:
+ return predict
+ return match.group(1).strip()
+
+
+class DSRunner:
+ def __init__(self, model_id: str, gpu_id: int = 7, dtype=torch.float16):
+ self.device = torch.device(f"cuda:{gpu_id}")
+ torch.cuda.set_device(self.device)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_id, padding_side="left", trust_remote_code=True)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ base = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=dtype,
+ trust_remote_code=True,
+ ).to(self.device).eval()
+
+ self.model = deepspeed.init_inference(
+ base,
+ mp_size=1,
+ dtype=dtype,
+ replace_method="auto",
+ replace_with_kernel_inject=True,
+ ).module
+
+ # ↳ returns **len(prompts) * n** strings, grouped per-prompt
+ def generate(self, prompts, *, n=1, max_new_tokens=32,
+ temperature=0.0, top_p=1.0):
+ cfg = GenerationConfig(
+ do_sample=temperature > 0,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ num_return_sequences=n,
+ )
+
+ enc = self.tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding=True,
+ truncation=False
+ ).to(self.device)
+
+ with torch.no_grad():
+ out = self.model.generate(**enc, generation_config=cfg)
+
+ # split into groups of `n` per original prompt
+ out = out.view(len(prompts), n, -1)
+ completions = []
+ for prompt, rows in zip(prompts, out):
+ full = self.tokenizer.batch_decode(rows, skip_special_tokens=True)
+ completions.extend([s[len(prompt):].strip() for s in full])
+ return completions
+
+
+class Qwen2VLGRPOVLLMTrainerSelfConst(Trainer):
+ def __init__(
+ self,
+ model: Union[str, PreTrainedModel],
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
+ args: GRPOConfig = None,
+ script_args = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+ eval_dataset: Optional[
+ Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
+ ] = None,
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
+ reward_processing_classes: Optional[
+ Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
+ ] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[
+ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
+ ] = (None, None),
+ peft_config: Optional["PeftConfig"] = None,
+ # qwen2-vl related params
+ max_pixels: Optional[int] = 12845056,
+ min_pixels: Optional[int] = 3136,
+ attn_implementation: str = "flash_attention_2",
+ ):
+
+ # Args
+ if args is None:
+ model_name = model if isinstance(model, str) else model.config._name_or_path
+ model_name = model_name.split("/")[-1]
+ args = GRPOConfig(f"{model_name}-GRPO")
+
+ # Models
+ # Trained model
+ model_init_kwargs = args.model_init_kwargs or {}
+ model_init_kwargs["attn_implementation"] = attn_implementation
+ if isinstance(model, str):
+ model_id = model
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if (
+ isinstance(torch_dtype, torch.dtype)
+ or torch_dtype == "auto"
+ or torch_dtype is None
+ ):
+ pass # torch_dtype is already a torch.dtype or "auto" or None
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
+ torch_dtype = getattr(torch, torch_dtype)
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ else:
+ raise ValueError(
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+ )
+ # Disable caching if gradient checkpointing is enabled (not supported)
+ model_init_kwargs["use_cache"] = (
+ False
+ if args.gradient_checkpointing
+ else model_init_kwargs.get("use_cache")
+ )
+ if "Qwen2-VL" in model_id:
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ elif "Qwen2.5-VL" in model_id:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ elif "Aria" in model_id:
+ model_init_kwargs.pop("use_cache")
+ model = AriaForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ else:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ else:
+ model_id = model.config._name_or_path
+ if args.model_init_kwargs is not None:
+ raise ValueError(
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+ "This argument can only be used when the `model` argument is a string."
+ )
+
+ if peft_config is not None:
+ model = get_peft_model(model, peft_config)
+
+ # Reference model
+ if is_deepspeed_zero3_enabled():
+ if "Qwen2-VL" in model_id:
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif "Qwen2.5-VL" in model_id:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif "Aria" in model_id:
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ else:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif peft_config is None:
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
+ self.ref_model = create_reference_model(model)
+ else:
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
+ # to revert to the initial model.
+ self.ref_model = None
+
+ # Processing class
+ if processing_class is None:
+ if "Qwen" in model_id or "Aria" in model_id:
+ processing_class = AutoProcessor.from_pretrained(model_id)
+ pad_token_id = processing_class.tokenizer.pad_token_id
+ processing_class.pad_token_id = pad_token_id
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+ if "Qwen" in model_id:
+ processing_class.image_processor.max_pixels = max_pixels
+ processing_class.image_processor.min_pixels = min_pixels
+ else:
+ processing_class = AutoTokenizer.from_pretrained(
+ model.config._name_or_path, padding_side="left"
+ )
+ pad_token_id = processing_class.pad_token_id
+
+ # Reward functions
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+ for i, reward_func in enumerate(reward_funcs):
+ if isinstance(reward_func, str):
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+ reward_func, num_labels=1, **model_init_kwargs
+ )
+ self.reward_funcs = reward_funcs
+
+ # Reward processing class
+ if reward_processing_classes is None:
+ reward_processing_classes = [None] * len(reward_funcs)
+ elif not isinstance(reward_processing_classes, list):
+ reward_processing_classes = [reward_processing_classes]
+ else:
+ if len(reward_processing_classes) != len(reward_funcs):
+ raise ValueError(
+ "The number of reward processing classes must match the number of reward functions."
+ )
+
+ for i, (reward_processing_class, reward_func) in enumerate(
+ zip(reward_processing_classes, reward_funcs)
+ ):
+ if isinstance(reward_func, PreTrainedModel):
+ if reward_processing_class is None:
+ reward_processing_class = AutoTokenizer.from_pretrained(
+ reward_func.config._name_or_path
+ )
+ if reward_processing_class.pad_token_id is None:
+ reward_processing_class.pad_token = (
+ reward_processing_class.eos_token
+ )
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+ reward_processing_classes[i] = reward_processing_class
+ self.reward_processing_classes = reward_processing_classes
+
+ # Data collator
+ def data_collator(features): # No data collation is needed in GRPO
+ return features
+
+ # Training arguments
+ self.max_prompt_length = args.max_prompt_length
+ self.max_completion_length = (
+ args.max_completion_length
+ ) # = |o_i| in the GRPO paper
+ self.num_generations = args.num_generations # = G in the GRPO paper
+ self.temporal = script_args.temporal
+ self.generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ temperature=1, # HACK
+ num_return_sequences=self.num_generations,
+ pad_token_id=pad_token_id,
+ )
+ self.beta = args.beta
+
+ self.shuffled_num_generations = self.num_generations // 2
+ self.shuffled_generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1, # HACK
+ num_return_sequences=self.shuffled_num_generations,
+ pad_token_id=pad_token_id,
+ )
+
+ self.dummy_generation_config = GenerationConfig(
+ max_new_tokens=1,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1, # HACK
+ num_return_sequences=1,
+ pad_token_id=pad_token_id,
+ )
+ self.len_control = script_args.len_control
+ self.beta = args.beta
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+ # This acts as a flag to indicate that the warning has already been issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ # Initialize the metrics
+ self._metrics = defaultdict(list)
+ self.use_vllm = args.use_vllm
+
+
+ self.ds_infer = DSRunner(model_id="Qwen/Qwen2-0.5B-Instruct", gpu_id=7)
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ )
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ if self.use_vllm:
+ if not is_vllm_available():
+ raise ImportError(
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
+ "`pip install vllm` to use it."
+ )
+
+ if self.accelerator.is_main_process:
+ vllm_device = self.args.vllm_device
+ if vllm_device == "auto":
+ vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
+
+ # ──────────────────── NEW BEGIN ────────────────────────
+ # Accept a comma-separated list, e.g. "cuda:6,7"
+ # device_tokens = [tok.strip() for tok in vllm_device.split(",")]
+ # multi_gpu = len(device_tokens) > 1
+
+ # if multi_gpu:
+ # # keep only the numeric part ("cuda:6" -> "6")
+ # # physical_ids = [tok.split(":")[1] for tok in device_tokens]
+ # physical_ids = [tok.split(":")[-1] for tok in device_tokens]
+
+ # # Mask visibility *in this process only* (rank-0)
+ # os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(physical_ids)
+
+ # logical_device = "cuda" # vLLM sees them as 0,1,…
+ # tensor_parallel_size = len(physical_ids)
+ # else:
+ # logical_device = vllm_device # single id like "cuda:6"
+ # tensor_parallel_size = 1
+
+ # vllm_device = logical_device
+ # ──────────────────── NEW END ────────────────────────
+
+
+ # Check that the requested device is available
+ '''
+ The first if statement below is to guard vllm errors'''
+ # if (not multi_gpu) and vllm_device.startswith("cuda:"):
+ # gpu_idx = int(vllm_device.split(":")[1])
+ # if gpu_idx >= torch.cuda.device_count():
+ # raise ValueError(
+ # f"The requested device {vllm_device} is not available. "
+ # f"You only have {torch.cuda.device_count()} GPUs."
+ # )
+
+ # # ---------- overlap-with-training warning (skip for multi-GPU) ---------
+ # if (not multi_gpu) and vllm_device in {
+ # f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+ # }:
+ # warnings.warn(
+ # f"The requested vLLM device {vllm_device} is also used for training. "
+ # "This may lead to unexpected behaviour."
+ # )
+ if (
+ vllm_device.split(":")[0] == "cuda"
+ and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
+ ):
+ raise ValueError(
+ f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
+ "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
+ "value lower than the number of GPUs available on your machine—typically, reducing it by one "
+ f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+ )
+ # Check that the requested device is not also used for training
+ if vllm_device in {
+ f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+ }:
+ warnings.warn(
+ f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
+ "behavior. It is recommended to use a dedicated device for vLLM."
+ )
+ # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
+ # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
+ # setting (profiling_patch).
+ # world_size_patch = patch(
+ # "torch.distributed.get_world_size", return_value=1
+ # )
+
+ '''
+ Below is the cahnged code
+ '''
+ # world_size_patch = patch(
+ # "torch.distributed.get_world_size", return_value=tensor_parallel_size
+ # )
+
+ # profiling_patch = patch(
+ # "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+ # return_value=None,
+ # )
+ '''Above is the changed code'''
+
+ world_size_patch = patch(
+ "torch.distributed.get_world_size", return_value=1
+ )
+ profiling_patch = patch(
+ "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+ return_value=None,
+ )
+
+ '''
+ Below changes
+ '''
+ with world_size_patch, profiling_patch:
+ # with profiling_patch:
+ print("vllm is running on: ", vllm_device)
+ from vllm.config import ParallelConfig
+ self.llm = LLM(
+ model=model.name_or_path,
+ device=vllm_device,
+ # tensor_parallel_size=tensor_parallel_size, # ← 1 or N
+ # parallel_config=ParallelConfig( # ← NEW
+ # tensor_parallel_size=tensor_parallel_size
+ # ),
+ gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
+ dtype=torch.bfloat16,
+ # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
+ # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
+ # This is particularly useful here because we generate completions from the same prompts.
+ enable_prefix_caching=True,
+ enforce_eager=True,
+ mm_processor_kwargs=(
+ {
+ "max_pixels": max_pixels,
+ "min_pixels": min_pixels,
+ }
+ # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
+ if False
+ else None
+ ),
+ max_model_len=args.max_prompt_length + args.max_completion_length,
+ )
+ self.sampling_params = SamplingParams(
+ temperature=1.0,
+ top_p=0.95,
+ max_tokens=self.max_completion_length,
+ )
+
+ # self.second_sampling_params = SamplingParams(
+ # n = 1, # one generation
+ # temperature = 0.5, # less squeezing
+ # top_p = 0.9, # nucleus filter
+ # # top_k = 50, # (alternative to top_p)
+ # min_tokens = 4, # force at least 4 tokens
+ # max_tokens = self.max_completion_length,
+ # )
+ self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
+
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+ # synchronize all processes after vLLM has been fully initialized.
+ self.accelerator.wait_for_everyone()
+ else:
+ raise ValueError(
+ "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
+ )
+
+ if self.ref_model is not None:
+ if self.is_deepspeed_enabled:
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+ def _set_signature_columns_if_needed(self):
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+ if self._signature_columns is None:
+ self._signature_columns = ["prompt"]
+
+ # Get the per-token log probabilities for the completions for the model and the reference model
+ def _get_per_token_logps(self, model, input_ids, **kwargs):
+ # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
+ # import pdb
+ # pdb.set_trace()
+ logits = model(input_ids, **kwargs).logits
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+ per_token_logps = []
+ for logits_row, input_ids_row in zip(logits, input_ids):
+ log_probs = logits_row.log_softmax(dim=-1)
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+ per_token_logps.append(token_log_prob)
+ return torch.stack(per_token_logps)
+
+ # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
+ # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
+ def _prepare_inputs(
+ self, inputs: dict[str, Union[torch.Tensor, Any]]
+ ) -> dict[str, Union[torch.Tensor, Any]]:
+ return inputs
+
+ def remove_none_from_data(self, data):
+ for entry in data:
+ if "content" in entry and isinstance(entry["content"], list):
+ for sub_entry in entry["content"]:
+ if isinstance(sub_entry, dict):
+ keys_to_remove = [k for k, v in sub_entry.items() if v is None]
+ for k in keys_to_remove:
+ del sub_entry[k]
+ return data
+
+
+
+ def compute_loss(
+ self, model, inputs, return_outputs=False, num_items_in_batch=None
+ ):
+ if return_outputs:
+ raise ValueError("The GRPOTrainer does not support returning outputs")
+ # Compute the per-token log probabilities for the model
+
+
+ device = self.accelerator.device
+ prompts = [x["prompt"] for x in inputs]
+ # images = [x["image"] for x in inputs]
+ prompts_text = [
+ maybe_apply_chat_template(example, self.processing_class)["prompt"]
+ for example in inputs
+ ]
+
+ input_copy = copy.deepcopy(inputs[0]['prompt'])
+
+ input_copy = self.remove_none_from_data(input_copy)
+
+ data_type = inputs[0]['data_type']
+
+ if data_type == 'image':
+ input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+ elif data_type == 'video':
+ input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+
+
+ image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+
+
+ prompt_inputs = self.processing_class(
+ text=copy.deepcopy(prompts_text),
+ images=image_inputs,
+ videos=video_inputs,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ add_special_tokens=False,
+ )
+
+ mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+ if self.max_prompt_length is not None:
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+
+ if self.temporal:
+ if video_inputs:
+ indices = torch.randperm(video_inputs[0].size(0))
+ shuffled_video_inputs = [video_inputs[0][indices]]
+ shuffled_prompt_inputs = self.processing_class(
+ text=copy.deepcopy(prompts_text),
+ images=image_inputs,
+ videos=shuffled_video_inputs,
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ add_special_tokens=False,
+ )
+ shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
+ shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
+ shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
+ if self.max_prompt_length is not None:
+ shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
+ shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
+ else:
+ shuffled_mm_data = [None]
+
+
+
+ if self.args.use_vllm:
+ # First, have main process load weights if needed
+ if self.state.global_step != self._last_loaded_step:
+ with unwrap_model_for_generation(
+ self.model,
+ self.accelerator,
+ gather_deepspeed3_params=True, # TODO: fix this, self.args.ds3_gather_for_generation,
+ ) as unwrapped_model:
+ if is_compiled_module(unwrapped_model):
+ state_dict = unwrapped_model._orig_mod.state_dict()
+ else:
+ state_dict = unwrapped_model.state_dict()
+ if self.accelerator.is_main_process:
+ llm_model = (
+ self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+ )
+ # import pdb
+ # pdb.set_trace()
+ llm_model.load_weights(state_dict.items())
+ self._last_loaded_step = self.state.global_step
+
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+ all_prompts_text = gather_object(prompts_text)
+ all_mm_data = gather_object(mm_data)
+ # group into pairs
+ all_multimodal_inputs = []
+
+ if self.temporal:
+ shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
+ shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
+ shuffled_all_multimodal_inputs = []
+
+ # 2. Refer to TobiasLee's implementation suggestions
+ # this is a better implementation for vLLM sampling.
+ for prompt, mm_item in zip(all_prompts_text, all_mm_data):
+ all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ for mm_item in shuffled_all_mm_data:
+ shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})
+
+ # Create sampling params with num_generations
+ if self.accelerator.is_main_process:
+ # Clone to avoid modifying original params
+ sampling_params = copy.deepcopy(self.sampling_params)
+ sampling_params.n = self.num_generations
+ # Single generate call with all prompts
+ if self.accelerator.is_main_process:
+ outputs = self.llm.generate(
+ all_multimodal_inputs,
+ sampling_params=sampling_params,
+ use_tqdm=False,
+ )
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
+ completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ # Clone to avoid modifying original params
+ shuffled_sampling_params = copy.deepcopy(self.sampling_params)
+ shuffled_sampling_params.n = self.num_generations // 2
+ # Single generate call with all prompts
+ if self.accelerator.is_main_process:
+ shuffled_outputs = self.llm.generate(
+ shuffled_all_multimodal_inputs,
+ sampling_params=shuffled_sampling_params,
+ use_tqdm=False,
+ )
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
+ shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]
+
+
+ else:
+ completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)
+
+
+ # broadcast and slice
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
+ process_slice = slice(
+ self.accelerator.process_index * len(prompts) * self.num_generations,
+ (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
+ )
+ completion_ids = completion_ids[process_slice]
+
+ # Pad the completions, and concatenate them with the prompts
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+ completion_ids = pad(
+ completion_ids, padding_value=self.processing_class.pad_token_id
+ )
+ prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+
+ prompt_length = prompt_ids.size(1)
+
+ # print('prompt_length:', prompt_length)
+
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
+ completion_ids = prompt_completion_ids[:, prompt_length:]
+ prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
+
+
+ if self.temporal and shuffled_all_mm_data!=[]:
+ # broadcast and slice
+ shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
+ process_id_list = []
+ for mm_item in shuffled_all_mm_data:
+ process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
+
+ if video_inputs:
+ cur_shuffled_completion_ids = []
+ for i in range(len(process_id_list)):
+ if self.accelerator.process_index == process_id_list[i]:
+ cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
+
+ # Pad the completions, and concatenate them with the prompts
+ cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
+ cur_shuffled_completion_ids = pad(
+ cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
+ )
+ shuffled_completion_ids = cur_shuffled_completion_ids
+
+
+ else:
+ raise ValueError("Only vLLM generation is supported in this version ")
+
+ # below are the same with yifan's code
+ # Mask everything after the first EOS token
+ is_eos = completion_ids == self.processing_class.eos_token_id
+ device = self.accelerator.device
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+
+
+ prompt_inputs.pop("input_ids")
+ prompt_inputs.pop("attention_mask")
+
+ if data_type == 'image':
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ # import pdb; pdb.set_trace()
+
+
+ if data_type == 'video':
+ prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ if 'second_per_grid_ts' in prompt_inputs:
+ del prompt_inputs["second_per_grid_ts"]
+
+ # import pdb
+ # pdb.set_trace()
+
+ # per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+ per_token_logps = per_token_logps[:, prompt_length - 1 :]
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ with torch.inference_mode():
+ if self.ref_model is not None:
+ ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
+ else:
+ with self.accelerator.unwrap_model(model).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+
+ x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10) # 限制 x 的范围
+ per_token_kl = torch.exp(x_clamped) - x_clamped - 1
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if self.temporal and video_inputs:
+
+ shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
+ if is_conversational(inputs[0]):
+ shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
+
+ # Compute the rewards
+ shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
+ shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+ shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+ for key in shuffled_reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column for `num_generations` times
+ shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
+ shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
+ shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
+
+
+
+ # Decode the generated completions
+ completions = self.processing_class.batch_decode(
+ completion_ids, skip_special_tokens=True
+ )
+
+ '''
+ Below code is added for second round generation
+ '''
+ # second_stage_prompts_text = completions # ← list[str]
+ # curr_problem = example['problem']
+ # print('curr problem is: ', curr_problem)
+ # problem_key = "problem" if "problem" in inputs[0] else "question"
+ # # ─── For each sample in the batch, repeat its problem text num_generations times
+ # problems_aligned = [
+ # str(ex[problem_key])
+ # for ex in inputs
+ # for _ in range(self.num_generations)
+ # ]
+
+ # 1️⃣ descriptions extracted from first-round completions
+ second_stage_prompts_descriptions = [
+ str(extract_info(c) or "") # len = B * n_gen
+ for c in completions
+ ]
+
+ # 2️⃣ obtain + template the verify prompt for every sample,
+ # then repeat it n_gen times to align with descriptions
+ verify_templates = []
+ for ex in inputs: # B samples
+ tmpl = ex["verify_prompt"] # may be dict or str
+
+ # ▸ if it's still a dict, wrap it NOW
+ if not isinstance(tmpl, str):
+ tmpl = maybe_apply_chat_template(
+ tmpl, # conversation-dict
+ self.processing_class
+ )["prompt"] # templated string
+
+ verify_templates.extend([tmpl] * self.num_generations)
+
+ # 3️⃣ fill the {description} or {Description} slot
+ def fill_template(tmpl: str, desc: str) -> str:
+ # Replace both spelling variants and avoid all other {…} in the string
+ return (tmpl
+ .replace("{Description}", desc)
+ .replace("{description}", desc))
+
+ second_stage_chat_prompts = [
+ fill_template(tmpl, desc)
+ for tmpl, desc in zip(verify_templates, second_stage_prompts_descriptions)
+ ]
+
+ # 4️⃣ reward-model generation (DeepSpeed, GPU 7)
+ all_second_prompts_text = gather_object(second_stage_chat_prompts)
+
+ if self.accelerator.is_main_process:
+ sp = self.sampling_params
+ # • get num_generations completions per prompt
+ second_texts = self.reward_infer.generate(
+ all_second_prompts_text,
+ n=self.num_generations,
+ max_new=sp.max_tokens,
+ temp=sp.temperature,
+ top_p=sp.top_p,
+ )
+ second_completion_ids = [
+ self.reward_infer.tok.encode(t, add_special_tokens=False)
+ for t in second_texts
+ ]
+ else:
+ second_completion_ids = [None] * len(all_second_prompts_text) * self.num_generations
+
+
+ # 5️⃣ Broadcast / slice back to every process
+ second_completion_ids = broadcast_object_list(second_completion_ids, from_process=0)
+ process_slice2 = slice(
+ self.accelerator.process_index * len(second_stage_prompts_text) * self.num_generations,
+ (self.accelerator.process_index + 1) * len(second_stage_prompts_text) * self.num_generations,
+ )
+ second_completion_ids = second_completion_ids[process_slice2]
+
+ # 6️⃣ Pad & move to device-7
+ device = self.reward_infer.device
+ second_completion_ids = [torch.tensor(ids, device=device) for ids in second_completion_ids]
+ second_completion_ids = pad(
+ second_completion_ids, padding_value=self.processing_class.pad_token_id
+ )
+
+ # 7️⃣ Decode the second-round generations
+ second_completions = self.processing_class.batch_decode(
+ second_completion_ids, skip_special_tokens=True
+ )
+
+
+ print('Second completions: ')
+ print(second_completions[0])
+ print('*'*10)
+ time.sleep(40)
+
+ # 8️⃣ (Optional) wrap conversationally, log, or feed into further
+ # reward computation just like the first-round completions.
+ # For example:
+ # if is_conversational(inputs[0]):
+ # second_completions = [
+ # [{"role": "assistant", "content": c}] for c in second_completions
+ # ]
+
+ second_round_info = {
+ "second_prompts": second_stage_prompts_text, # list[str]
+ "second_completions": second_completions, # list[str]
+ }
+ '''
+ Above code is added for second round generation
+ '''
+
+
+
+
+ if is_conversational(inputs[0]):
+ completions = [
+ [{"role": "assistant", "content": completion}]
+ for completion in completions
+ ]
+
+ # Compute the rewards
+ prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+ rewards_per_func = torch.zeros(
+ len(prompts), len(self.reward_funcs), device=device
+ )
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ reward_kwargs = {
+ key: []
+ for key in inputs[0].keys()
+ if key not in ["prompt", "completion"]
+ }
+
+ # reward_kwargs.update(second_round_info)
+
+ for key in reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column for `num_generations` times
+ reward_kwargs[key].extend([example[key]] * self.num_generations)
+
+ reward_kwargs["second_prompts"] = second_stage_prompts_text # len = len(completions)
+ reward_kwargs["second_completions"] = second_completions
+
+ output_reward_func = reward_func(
+ prompts=prompts, completions=completions, **reward_kwargs
+ )
+ rewards_per_func[:, i] = torch.tensor(
+ output_reward_func, dtype=torch.float32, device=device
+ )
+
+
+ # rewards_per_func = gather(rewards_per_func)
+ # # Sum the rewards from all reward functions
+ # rewards = rewards_per_func.sum(dim=1)
+
+ # process_slice = slice(
+ # self.accelerator.process_index * len(prompts),
+ # (self.accelerator.process_index + 1) * len(prompts),
+ # )
+
+ # rewards = rewards[process_slice]
+
+
+
+ if self.temporal and video_inputs:
+ temporal_rewards_per_func = rewards_per_func.clone()
+
+ acc_mean = temporal_rewards_per_func[:, 0].mean()
+ shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
+
+ if acc_mean >= 0.8 * shuffled_acc_mean:
+ mask = temporal_rewards_per_func[:, 0] > 0.1
+ temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
+ temporal_rewards = torch.tensor([1.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.5]).to('cuda')
+
+ # Sum the rewards from all reward functions
+ if self.temporal and video_inputs:
+ rewards = temporal_rewards_per_func.sum(dim=1)
+ else:
+ rewards = rewards_per_func.sum(dim=1)
+
+ if self.len_control:
+ mem_rewards = [0] * self.num_generations
+ mask = rewards_per_func[:, 0] > 0.1
+ lenth_list = completion_mask.sum(1)
+ selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
+ # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
+ # if len(selected_indices) > 1:
+ # selected_items = [(i, lenth_list[i]) for i in selected_indices]
+ # sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
+ # N = len(sorted_items)
+ # for rank, (idx, length) in enumerate(sorted_items):
+ # reward = 0.2 - 0.2 * (rank / N)
+ # rewards[idx] += reward
+ # mem_rewards[idx] = reward
+ # for idx in range(len(lenth_list)):
+ # if lenth_list[idx] >= 512:
+ # rewards[idx] -= 0.5
+
+ if len(selected_indices) > 1:
+ for idx in selected_indices:
+ if 320 <= lenth_list[idx] <= 512:
+ rewards[idx] += 0.2
+
+ # print(rewards)
+ # print(completion_mask.sum(1))
+
+ # Compute grouped-wise rewards
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+ # Normalize the rewards to compute the advantages
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+ # x - x.detach() allows for preserving gradients from x
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+ # per_token_loss = -per_token_loss
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+
+ # import pdb
+ # pdb.set_trace()
+
+ # Log the metrics
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+ self._metrics["completion_length"].append(completion_length)
+
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+ else:
+ reward_func_name = reward_func.__name__
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+ gathered_rewards = self.accelerator.gather_for_metrics(rewards)
+
+ num_devices = gathered_rewards.size(0) // self.num_generations
+ rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
+ wrong_devices = (rewards_per_device <= 1).all(dim=1)
+ wrong_ratio = wrong_devices.sum().item() / num_devices
+
+ correct_devices = (rewards_per_device >= 2).all(dim=1)
+ correct_ratio = correct_devices.sum().item() / num_devices
+
+ self._metrics["all_wrong"].append(wrong_ratio)
+ self._metrics["all_correct"].append(correct_ratio)
+
+ if self.temporal:
+ temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
+ self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
+
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+
+ return loss
+
+
+
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
+
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+ if next(iter(logs.keys())).startswith("eval_"):
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+ logs = {**logs, **metrics}
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ super().log(logs, start_time)
+ else: # transformers<=4.46
+ super().log(logs)
+ self._metrics.clear()
\ No newline at end of file
diff --git a/src/r1-v/src/open_r1/utils/gpt_eval.py b/src/r1-v/src/open_r1/utils/gpt_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30d46ef9515ae11a2db550105e4370cc7ca8715
--- /dev/null
+++ b/src/r1-v/src/open_r1/utils/gpt_eval.py
@@ -0,0 +1,98 @@
+import os
+from openai import AzureOpenAI
+import time
+
+import base64
+from mimetypes import guess_type
+
+# Function to encode a local image into data URL
+def local_image_to_data_url(image_path):
+ # Guess the MIME type of the image based on the file extension
+ mime_type, _ = guess_type(image_path)
+ if mime_type is None:
+ mime_type = 'application/octet-stream' # Default MIME type if none is found
+
+ # Read and encode the image file
+ with open(image_path, "rb") as image_file:
+ base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
+
+ # Construct the data URL
+ return f"data:{mime_type};base64,{base64_encoded_data}"
+
+
+def azure_gpt4(messages, model):
+ outputs = []
+ for message in messages:
+ input_prompt = [
+ { "role": "system", "content": "You are a helpful assistant." },
+ { "role": "user", "content": [
+ {
+ "type": "text",
+ "text": message["instruction"]
+ },
+ # {
+ # "type": "image_url",
+ # "image_url": {
+ # "url": message["image"]
+ # }
+ # }
+ ]}
+ ]
+ ## try N times if API exceed limit ...
+ for i in range(10):
+ try:
+ output = client.chat.completions.create(
+ model=model, messages=input_prompt, max_tokens=2000
+ )
+
+ output_text = output.choices[0].message.content
+ break ## exit if successful
+
+ except Exception as e:
+ print(f'Index {i} got error message: {e}')
+ output_text = ''
+ time.sleep(3)
+
+ outputs.append(output_text)
+
+ return outputs
+
+
+client = AzureOpenAI(
+ api_key = "83f30a2a22324395b854bd343db38d85",
+ api_version = "2024-08-01-preview",
+ azure_endpoint = "https://francecentral.api.cognitive.microsoft.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
+ )
+
+model = "gpt-4o"
+prompt_template = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. Provide your answer as a single final answer or a short phrase enclosed with . If the question is a multiple choice, the final answer should be a single letter choice. \nText description: {text}\nQuestion: {question}'''
+
+
+def infer(text, prompt_question):
+ prompt_question = prompt_question.replace('', '')
+ prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)
+
+ messages = [
+ {"instruction": prompt},
+ ]
+ prompt_success = False
+ prompt_time = 0
+ outputs = [' None ']
+ while prompt_success == False and prompt_time <= 2:
+ try:
+ outputs = azure_gpt4(messages, model)
+ prompt_success = True
+ except:
+ prompt_time += 1
+ time.sleep(5)
+
+ return outputs[0]
+
+
+# info = '''The image is a geometric diagram of a circle with center R. The circle has five points labeled S, T, V, U, and R. Line segments RS and TU are drawn from the center R to the circumference, forming two angles. Lines SU and RV are also drawn from the center R, intersecting the circumference at points U and V, respectively. Point S is directly opposite point T on the circumference. The length of segment SU is given as 16.2. The line segments RS, TU, and RV are blue, indicating they are radii of the circle, which are all equal in length to the radius R of the circle. Points S and T are connected by a line segment, as are points U and V. The circle is centered at point R.'''
+# question = "What is the radius of the circle?"
+# print(infer(info, question))
+
+# # Another inference
+# question2 = "Which points are connected by a blue segment?"
+# print(infer(info, question2))
diff --git a/src/r1-v/src/open_r1/utils/llm_direct_eval.py b/src/r1-v/src/open_r1/utils/llm_direct_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..d193e71e7f9f394b8bf00dba61564c5607e5b392
--- /dev/null
+++ b/src/r1-v/src/open_r1/utils/llm_direct_eval.py
@@ -0,0 +1,50 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+# Set the model and device
+model_name = "Qwen/Qwen2.5-7B-Instruct"
+device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
+
+# Load model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.float16, # Or use "auto" if supported by your setup
+ device_map={"": device.index} # Places model on cuda:7
+).to(device)
+
+prompt_template = '''You are an analytical assistant designed to evaluate texts and answer questions based on strict criteria. Follow these steps: Analyze the Text: Check if the provided text contains answers, solutions, explanations, problem-solving, or interpretations (e.g., reasoning steps, conclusions, causal statements like "because" or "therefore"). If any such elements exist, classify the text as non-descriptive. Determine Response: If the text is purely descriptive (e.g., objectively describing images, diagrams, or scenes without explanations/answers), answer the user's question using only the description in a single word or phrase enclosed with . If the text is non-descriptive, respond with Hacking Sample.\nText: {text}\nQuestion: {question}'''
+
+def infer(text, prompt_question):
+ prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)
+ # Tokenize
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+ # Generate
+ with torch.no_grad():
+ outputs = model.generate(
+ input_ids,
+ max_new_tokens=1024,
+ temperature=0.0,
+ top_k=1,
+ top_p=1.0,
+ do_sample=False,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id
+ )
+ generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ # Remove prompt from output if echoed
+ if generated.startswith(prompt):
+ return generated[len(prompt):].lstrip()
+ return generated
+
+# Example usage
+
+
+# Example usage (can be called as many times as needed, fast!):
+info = '''The image is a geometric diagram of a circle with center R. The circle has five points labeled S, T, V, U, and R. Line segments RS and TU are drawn from the center R to the circumference, forming two angles. Lines SU and RV are also drawn from the center R, intersecting the circumference at points U and V, respectively. Point S is directly opposite point T on the circumference. The length of segment SU is given as 16.2. The line segments RS, TU, and RV are blue, indicating they are radii of the circle, which are all equal in length to the radius R of the circle. Points S and T are connected by a line segment, as are points U and V. The circle is centered at point R.'''
+question = "What is the radius of the circle?"
+print(infer(info, question))
+
+# Another inference
+question2 = "Which points are connected by a blue segment?"
+print(infer(info, question2))
diff --git a/src/r1-v/src/open_r1/utils/llm_eval.py b/src/r1-v/src/open_r1/utils/llm_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac2a181693bde595e3fd662bbf8aae71ab1891ee
--- /dev/null
+++ b/src/r1-v/src/open_r1/utils/llm_eval.py
@@ -0,0 +1,31 @@
+import ray
+from vllm import LLM, SamplingParams
+
+ray.init()
+
+@ray.remote(
+ num_gpus=1,
+ runtime_env={"env_vars": {"CUDA_VISIBLE_DEVICES": "7"}}
+)
+class VLLMActor:
+ def __init__(self):
+ import os
+ self.gpu = os.environ["CUDA_VISIBLE_DEVICES"]
+ self.prompt_template = '''You are an analytical assistant designed to evaluate texts and answer questions based on strict criteria. Follow these steps: Analyze the Text: Check if the provided text contains answers, solutions, explanations, problem-solving, or interpretations (e.g., reasoning steps, conclusions, causal statements like "because" or "therefore"). If any such elements exist, classify the text as non-descriptive. Determine Response: If the text is purely descriptive (e.g., objectively describing images, diagrams, or scenes without explanations/answers), answer the user's question using only the description in a single word or phrase enclosed with . If the text is non-descriptive, respond with Hacking Sample.\nText: {text}\nQuestion: {question}'''
+
+ def infer(self, text, prompt_question):
+ llm = LLM(
+ model="Qwen/Qwen2.5-7B-Instruct",
+ tensor_parallel_size=1,
+ max_model_len=2048,
+ gpu_memory_utilization=0.7,
+ )
+ sampling_params = SamplingParams(temperature=0.0, top_k=1, top_p=1.0, max_tokens=1024)
+ outputs = llm.generate([self.prompt_template.replace('{text}', text).replace('{question}', prompt_question)], sampling_params)
+ return outputs[0].outputs[0].text
+
+# actor = VLLMActor.remote()
+
+# info = '''The image is a geometric diagram of a circle with center R. The circle has five points labeled S, T, V, U, and R. Line segments RS and TU are drawn from the center R to the circumference, forming two angles. Lines SU and RV are also drawn from the center R, intersecting the circumference at points U and V, respectively. Point S is directly opposite point T on the circumference. The length of segment SU is given as 16.2. The line segments RS, TU, and RV are blue, indicating they are radii of the circle, which are all equal in length to the radius R of the circle. Points S and T are connected by a line segment, as are points U and V. The circle is centered at point R.'''
+# print(ray.get(actor.infer.remote(info)))
+
diff --git a/src/r1-v/src/open_r1/utils/math_cot.py b/src/r1-v/src/open_r1/utils/math_cot.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea2f4217fa020b86de8ee65f9220fe4c1bfb7a1
--- /dev/null
+++ b/src/r1-v/src/open_r1/utils/math_cot.py
@@ -0,0 +1,112 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Dict, List, Optional
+from mathruler.grader import extract_boxed_content, grade_answer
+
+
+def extract_info(predict: str) -> Optional[str]:
+ """
+ Extracts the content within ... tags.
+ Returns the inner text (with leading/trailing whitespace stripped),
+ or None if no tag is found.
+ """
+ match = re.search(r"([\s\S]*?)", predict, re.DOTALL)
+ if not match:
+ return None
+ return match.group(1).strip()
+
+def format_reward(predict: str) -> float:
+ # Define a pattern that requires:
+ # 1) …
+ # 2) …
+ # 3) …
+ # with optional whitespace between sections, and dot matching newlines.
+ pattern = re.compile(
+ r"^\s*[\s\S]+?\s*"
+ r"[\s\S]+?\s*"
+ r"[\s\S]+?\s*$",
+ re.DOTALL
+ )
+ return 1.0 if pattern.match(predict) else 0.0
+
+
+def extract_math_answer(text: str) -> str:
+ """
+ 1) Try the full … block.
+ 2) If that is missing, grab whatever follows the opening tag.
+ 3) Otherwise return the original text.
+ """
+ # ① normal case …
+ m = re.search(r'\s*(.*?)\s*', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ② fallback …
+ m = re.search(r'\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ③ nothing found
+ return text.strip()
+
+def single_accuracy_reward(predict: str, ground_truth: str) -> float:
+ # answer = extract_boxed_content(predict)
+ # print('Predict: ')
+ # print(predict)
+ # print('Sol')
+ # print(ground_truth)
+ # print('-'*20)
+ # answer = extract_math_answer(predict)
+ answer = predict
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+def math_accuracy_reward(predict: str, ground_truth: str) -> float:
+ # answer = extract_boxed_content(predict)
+ print('Predict: ')
+ print(predict)
+ print('Sol')
+ print(ground_truth)
+ print('-'*20)
+ answer = extract_math_answer(predict)
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+
+def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
+ scores = []
+ for predict, ground_truth in zip(predicts, ground_truths):
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict) # handle qwen2.5vl-32b format
+ format_score = format_reward(predict)
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
+ scores.append(
+ {
+ "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+ "format": format_score,
+ "accuracy": accuracy_score,
+ }
+ )
+
+ return scores
+
+
+def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> Dict[str, float]:
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
+ # format_score = format_reward(predict)
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
+
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
+ return accuracy_score
+
diff --git a/src/r1-v/src/open_r1/utils/math_cot_noInfo.py b/src/r1-v/src/open_r1/utils/math_cot_noInfo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b669548036ef50f6fd5dc6b82407cf7194d64794
--- /dev/null
+++ b/src/r1-v/src/open_r1/utils/math_cot_noInfo.py
@@ -0,0 +1,81 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Dict, List, Optional
+
+from mathruler.grader import extract_boxed_content, grade_answer
+
+
+def format_reward(predict: str) -> float:
+ # Define a pattern that requires:
+ # 1) …
+ # 2) …
+ # 3) …
+ # with optional whitespace between sections, and dot matching newlines.
+ pattern = re.compile(
+ r"[\s\S]+?\s*"
+ r"[\s\S]+?\s*$",
+ re.DOTALL
+ )
+ return 1.0 if pattern.match(predict) else 0.0
+
+
+def extract_answer(predict: str) -> Optional[str]:
+ """
+ Extracts the content of the … block from `predict`.
+ Returns the inner text (with leading/trailing whitespace stripped),
+ or None if no tag is found.
+ """
+ match = re.search(r"([\s\S]*?)", predict, re.DOTALL)
+ if not match:
+ return None
+ return match.group(1).strip()
+
+
+def accuracy_reward(predict: str, ground_truth: str) -> float:
+ # answer = extract_boxed_content(predict)
+ print('Predict: ')
+ print(predict)
+ print('Sol')
+ print(ground_truth)
+ print('-'*20)
+ answer = extract_answer(predict)
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
+ scores = []
+ for predict, ground_truth in zip(predicts, ground_truths):
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict) # handle qwen2.5vl-32b format
+ format_score = format_reward(predict)
+ accuracy_score = accuracy_reward(predict, ground_truth)
+ scores.append(
+ {
+ "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+ "format": format_score,
+ "accuracy": accuracy_score,
+ }
+ )
+
+ return scores
+
+
+def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.1) -> Dict[str, float]:
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
+ format_score = format_reward(predict)
+ accuracy_score = accuracy_reward(predict, ground_truth)
+
+ return (1 - format_weight) * accuracy_score + format_weight * format_score
+
diff --git a/src/r1-v/src/open_r1/utils/self_eval.py b/src/r1-v/src/open_r1/utils/self_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..35fa5b6193ec494baf1567bacc01fc22c1ff375a
--- /dev/null
+++ b/src/r1-v/src/open_r1/utils/self_eval.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python
+"""
+Offline batched generation for Qwen-2.5 with vLLM.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=0,1 python qwen25_vllm_offline.py
+"""
+from typing import List
+import os
+
+from vllm import LLM, SamplingParams
+from transformers import AutoTokenizer
+
+# ▶ 1. Which checkpoint?
+# Any base / chat / instruct variant works. Example: 7-B Chat.
+MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
+# MODEL_ID = "Video-R1/Video-R1-7B"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+# ▶ 2. How many GPUs to shard across?
+VISIBLE = os.environ.get("CUDA_VISIBLE_DEVICES", "7")
+TP = len(VISIBLE.split(",")) # tensor-parallel size
+
+# ▶ 3. Create the vLLM engine once
+llm = LLM(
+ model=MODEL_ID,
+ tensor_parallel_size=TP,
+ gpu_memory_utilization=0.80, # leave 10 % head-room
+ trust_remote_code=True, # Qwen needs this
+ max_model_len=32768, # full Qwen2.5 context window
+)
+
+# ▶ 4. Tokenizer, used only for chat templating
+tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+def _make_chat(text: str) -> str:
+ """Wrap raw user text in ChatML so Qwen answers correctly."""
+ messages = [{"role": "user", "content": text}]
+ return tok.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+
+def generate_batch(
+ prompts: List[str],
+ temperature: float = 0.4,
+ top_p: float = 0.8,
+ max_tokens: int = 1024,
+) -> List[str]:
+ """
+ Generate a single completion for every prompt in *prompts* and
+ return them as a list of strings (same order).
+ """
+ # 1. Convert each raw prompt into a chat-formatted string
+ chat_prompts = [_make_chat(p) for p in prompts]
+
+ # 2. Typical Qwen2.5 sampling settings
+ params = SamplingParams(
+ temperature=temperature,
+ top_p=top_p,
+ max_tokens=max_tokens,
+ )
+
+ # 3. Run vLLM. Each RequestOutput can hold n>1 candidates; we take the first
+ outputs = llm.generate(chat_prompts, params)
+ return [out.outputs[0].text for out in outputs]
+
+
+print(generate_batch(['Hi, how are you?', 'Describe dog in one sentence.']))
\ No newline at end of file
diff --git a/src/r1-v/src/open_r1/utils/utils.py b/src/r1-v/src/open_r1/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6112f0c8c5321a1b1aaa64fbab2f709566633215
--- /dev/null
+++ b/src/r1-v/src/open_r1/utils/utils.py
@@ -0,0 +1,147 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import importlib.resources as pkg_resources
+import json
+import random
+import itertools
+import warnings
+from collections import deque
+from dataclasses import dataclass, field
+from importlib.metadata import version
+from typing import Any, Literal, Optional, Union
+
+import datasets
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+import torch.utils.data
+from accelerate import Accelerator, PartialState
+from accelerate.state import AcceleratorState
+from huggingface_hub import ModelCard, ModelCardData
+from rich.console import Console
+from rich.table import Table
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import IterableDataset
+from transformers import (
+ BitsAndBytesConfig,
+ DataCollatorForLanguageModeling,
+ EvalPrediction,
+ GenerationConfig,
+ PreTrainedTokenizerBase,
+ TrainerState,
+ TrainingArguments,
+ is_comet_available,
+)
+from transformers.utils import (
+ is_peft_available,
+ is_torch_mlu_available,
+ is_torch_npu_available,
+ is_torch_xpu_available,
+)
+
+
+def get_all_parameters(sub_module, recurse=False):
+ return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
+
+
+def iter_params(module, recurse=False):
+ return [param for _, param in get_all_parameters(module, recurse)]
+
+def remove_hooks(model: "DeepSpeedEngine") -> None:
+ """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
+ return
+ if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+ optimizer_offload = model.optimizer.parameter_offload
+ elif model.optimizer is not None:
+ optimizer_offload = model.optimizer
+
+ for param in iter_params(optimizer_offload.module, recurse=True):
+ param.ds_active_sub_modules.clear()
+
+ for hook in optimizer_offload.forward_hooks:
+ hook.remove()
+ for hook in optimizer_offload.backward_hooks:
+ hook.remove()
+
+ optimizer_offload.forward_hooks = []
+ optimizer_offload.backward_hooks = []
+
+def add_hooks(model: "DeepSpeedEngine") -> None:
+ """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
+ return
+ if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+ optimizer_offload = model.optimizer.parameter_offload
+ elif model.optimizer is not None:
+ optimizer_offload = model.optimizer
+ optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+
+
+
+
+
+
+def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right") -> torch.Tensor:
+ """
+ Pads a list of tensors to the same shape along the first dimension.
+
+ Args:
+ tensors (`list[torch.Tensor]`):
+ List of input tensors to pad.
+ padding_value (`int`):
+ Value to use for padding. Default is 0.
+ padding_side (`str`):
+ Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
+
+ Returns:
+ `torch.Tensor`:
+ A single tensor containing the padded tensors.
+
+ Examples:
+ >>> import torch
+ >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
+ tensor([[1, 2, 3],
+ [4, 5, 0]])
+ >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])])
+ tensor([[[1, 2],
+ [3, 4]],
+
+ [[5, 6],
+ [0, 0]]])
+ """
+ # Determine the maximum shape for each dimension
+ output_shape = np.max([t.shape for t in tensors], 0).tolist()
+
+ # Create an output tensor filled with the padding value
+ output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
+
+ for i, t in enumerate(tensors):
+ # Determine the slice for the sequence dimension
+ if padding_side == "left":
+ seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0])
+ elif padding_side == "right":
+ seq_slice = slice(0, t.shape[0])
+ else:
+ raise ValueError("padding_side must be 'left' or 'right'")
+
+ slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
+ output[i][slices] = t
+
+ return output
+
+
diff --git a/src/r1-v/src/r1_v.egg-info/PKG-INFO b/src/r1-v/src/r1_v.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..1475557e5eb06fe22415fff3c00ad2ee281c1042
--- /dev/null
+++ b/src/r1-v/src/r1_v.egg-info/PKG-INFO
@@ -0,0 +1,59 @@
+Metadata-Version: 2.4
+Name: r1-v
+Version: 0.1.0
+Summary: R1-V
+Home-page: https://github.com/Deep-Agent/R1-V
+Author: The r1-v team and the Hugging Face team (past and future)
+License: Apache
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Education
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.10.9
+License-File: LICENSE
+Requires-Dist: accelerate>=1.2.1
+Requires-Dist: bitsandbytes>=0.43.0
+Requires-Dist: einops>=0.8.0
+Requires-Dist: datasets>=3.2.0
+Requires-Dist: deepspeed==0.15.4
+Requires-Dist: hf_transfer>=0.1.4
+Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2
+Requires-Dist: liger_kernel==0.5.2
+Requires-Dist: packaging>=23.0
+Requires-Dist: safetensors>=0.3.3
+Requires-Dist: sentencepiece>=0.1.99
+Requires-Dist: trl==0.16.0
+Provides-Extra: tests
+Requires-Dist: pytest; extra == "tests"
+Requires-Dist: parameterized>=0.9.0; extra == "tests"
+Provides-Extra: torch
+Requires-Dist: torch>=2.5.1; extra == "torch"
+Provides-Extra: quality
+Requires-Dist: black>=24.4.2; extra == "quality"
+Requires-Dist: isort>=5.12.0; extra == "quality"
+Requires-Dist: flake8>=6.0.0; extra == "quality"
+Provides-Extra: eval
+Requires-Dist: lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math] ; extra == "eval"
+Requires-Dist: math-verify; extra == "eval"
+Provides-Extra: dev
+Requires-Dist: black>=24.4.2; extra == "dev"
+Requires-Dist: isort>=5.12.0; extra == "dev"
+Requires-Dist: flake8>=6.0.0; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: parameterized>=0.9.0; extra == "dev"
+Requires-Dist: lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math] ; extra == "dev"
+Requires-Dist: math-verify; extra == "dev"
+Dynamic: author
+Dynamic: classifier
+Dynamic: home-page
+Dynamic: license
+Dynamic: license-file
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
diff --git a/src/r1-v/src/r1_v.egg-info/SOURCES.txt b/src/r1-v/src/r1_v.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a56a39892dea179c1e9745091dfa20eb751c6edf
--- /dev/null
+++ b/src/r1-v/src/r1_v.egg-info/SOURCES.txt
@@ -0,0 +1,25 @@
+LICENSE
+setup.cfg
+setup.py
+src/open_r1/__init__.py
+src/open_r1/evaluate.py
+src/open_r1/generate.py
+src/open_r1/grpo-cot-LLMEval.py
+src/open_r1/grpo-cot-answerBERT-eval.py
+src/open_r1/grpo-cot-noDesEval.py
+src/open_r1/grpo-cot-noInfo.py
+src/open_r1/grpo-cot-selfEval.py
+src/open_r1/grpo-cot.py
+src/open_r1/grpo.py
+src/open_r1/sft_video.py
+src/open_r1/trainer/__init__.py
+src/open_r1/trainer/grpo_trainer.py
+src/open_r1/trainer/vllm_grpo_trainer_modified.py
+src/open_r1/trainer/vllm_grpo_trainer_modified_error.py
+src/open_r1/trainer/vllm_grpo_trainer_modified_orig.py
+src/r1_v.egg-info/PKG-INFO
+src/r1_v.egg-info/SOURCES.txt
+src/r1_v.egg-info/dependency_links.txt
+src/r1_v.egg-info/not-zip-safe
+src/r1_v.egg-info/requires.txt
+src/r1_v.egg-info/top_level.txt
\ No newline at end of file
diff --git a/src/r1-v/src/r1_v.egg-info/dependency_links.txt b/src/r1-v/src/r1_v.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/src/r1-v/src/r1_v.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/src/r1-v/src/r1_v.egg-info/not-zip-safe b/src/r1-v/src/r1_v.egg-info/not-zip-safe
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/src/r1-v/src/r1_v.egg-info/not-zip-safe
@@ -0,0 +1 @@
+
diff --git a/src/r1-v/src/r1_v.egg-info/requires.txt b/src/r1-v/src/r1_v.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..228508908975935d84e93ad7399a29dfb8057f7c
--- /dev/null
+++ b/src/r1-v/src/r1_v.egg-info/requires.txt
@@ -0,0 +1,37 @@
+accelerate>=1.2.1
+bitsandbytes>=0.43.0
+einops>=0.8.0
+datasets>=3.2.0
+deepspeed==0.15.4
+hf_transfer>=0.1.4
+huggingface-hub[cli]<1.0,>=0.19.2
+liger_kernel==0.5.2
+packaging>=23.0
+safetensors>=0.3.3
+sentencepiece>=0.1.99
+trl==0.16.0
+
+[dev]
+black>=24.4.2
+isort>=5.12.0
+flake8>=6.0.0
+pytest
+parameterized>=0.9.0
+lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]
+math-verify
+
+[eval]
+lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]
+math-verify
+
+[quality]
+black>=24.4.2
+isort>=5.12.0
+flake8>=6.0.0
+
+[tests]
+pytest
+parameterized>=0.9.0
+
+[torch]
+torch>=2.5.1
diff --git a/src/r1-v/src/r1_v.egg-info/top_level.txt b/src/r1-v/src/r1_v.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e6b7848c0a475251df4011ac84c4c13d168c37e5
--- /dev/null
+++ b/src/r1-v/src/r1_v.egg-info/top_level.txt
@@ -0,0 +1 @@
+open_r1