Add files using upload-large-folder tool
Browse files- .gitattributes +1 -0
- eval/test_grounding_r1_nothink_ss.py +352 -0
- open-r1-multimodal/configs/qwen2vl_sft_config.yaml +42 -0
- open-r1-multimodal/data_config/gui_grounding.yaml +2 -0
- open-r1-multimodal/data_config/rec_internvl.yaml +4 -0
- open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json +3 -0
- open-r1-multimodal/local_scripts/create_vision_cot_data.py +153 -0
- open-r1-multimodal/local_scripts/zero3.yaml +22 -0
- open-r1-multimodal/setup.py +137 -0
- open-r1-multimodal/src/open_r1.egg-info/SOURCES.txt +32 -0
- open-r1-multimodal/src/open_r1.egg-info/not-zip-safe +1 -0
- open-r1-multimodal/src/open_r1/grpo.py +214 -0
- open-r1-multimodal/src/open_r1/sft.py +346 -0
- open-r1-multimodal/src/open_r1/trainer/__init__.py +5 -0
- open-r1-multimodal/src/open_r1/trainer/__pycache__/vllm_grpo_trainer.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/utils/__pycache__/math.cpython-310.pyc.139714633805856 +0 -0
- open-r1-multimodal/src/open_r1/utils/__pycache__/math.cpython-310.pyc.140170314805280 +0 -0
- open-r1-multimodal/src/open_r1/vlm_modules/__pycache__/internvl_module.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/vlm_modules/__pycache__/qwen_module.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py +238 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json filter=lfs diff=lfs merge=lfs -text
|
eval/test_grounding_r1_nothink_ss.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 2 |
+
from qwen_vl_utils import process_vision_info
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
import random
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
import argparse
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 19 |
+
|
| 20 |
+
def setup_distributed():
|
| 21 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 22 |
+
torch.cuda.set_device(local_rank)
|
| 23 |
+
|
| 24 |
+
dist.init_process_group(backend="nccl")
|
| 25 |
+
|
| 26 |
+
world_size = dist.get_world_size()
|
| 27 |
+
rank = dist.get_rank()
|
| 28 |
+
|
| 29 |
+
return local_rank, world_size, rank
|
| 30 |
+
|
| 31 |
+
local_rank, world_size, rank = setup_distributed()
|
| 32 |
+
device = f"cuda:{local_rank}"
|
| 33 |
+
print(f"Process {rank} using {device}")
|
| 34 |
+
|
| 35 |
+
steps = 3800
|
| 36 |
+
if rank == 0:
|
| 37 |
+
print("Steps: ", steps)
|
| 38 |
+
#RUN_NAME = "base"
|
| 39 |
+
RUN_NAME = "Qwen2.5-VL-7B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_filtered_only_one_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e"
|
| 40 |
+
#MODEL_PATH="/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-7B-Instruct"
|
| 41 |
+
MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
|
| 42 |
+
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
|
| 43 |
+
|
| 44 |
+
BSZ=32
|
| 45 |
+
DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot/metadata"
|
| 46 |
+
|
| 47 |
+
TEST_DATASETS = ['hf_test_full']
|
| 48 |
+
IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot/images"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# TEST_DATASETS = ['lisa_test']
|
| 52 |
+
# IMAGE_ROOT = "/data10/shz/dataset/lisa"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
| 56 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 57 |
+
MODEL_PATH,
|
| 58 |
+
torch_dtype=torch.bfloat16,
|
| 59 |
+
attn_implementation="flash_attention_2",
|
| 60 |
+
device_map={"": local_rank},
|
| 61 |
+
)
|
| 62 |
+
# default processer
|
| 63 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH,max_pixels=2007040,min_pixels=3136)
|
| 64 |
+
# processor.image_processor.min_pixels=3136
|
| 65 |
+
# processor.image_processor.max_pixels=2007040
|
| 66 |
+
print(processor.image_processor.min_pixels)
|
| 67 |
+
print(processor.image_processor.max_pixels)
|
| 68 |
+
# def extract_point_answer(content):
|
| 69 |
+
# # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
|
| 70 |
+
# answer_tag_pattern = r'<answer>(.*?)</answer>'
|
| 71 |
+
# content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
| 72 |
+
# if content_answer_match:
|
| 73 |
+
# content_answer = content_answer_match.group(1).strip()
|
| 74 |
+
# tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
|
| 75 |
+
# if tool_call_match:
|
| 76 |
+
# tool_call_content = tool_call_match.group(1).strip()
|
| 77 |
+
# # 解析 JSON
|
| 78 |
+
# tool_call_json = json.loads(tool_call_content)
|
| 79 |
+
# arguments = tool_call_json.get("arguments", {})
|
| 80 |
+
# coordinate = arguments.get("coordinate", None)
|
| 81 |
+
# if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 82 |
+
# x, y = coordinate
|
| 83 |
+
# extracted_coordinate = [x, y]
|
| 84 |
+
# return extracted_coordinate
|
| 85 |
+
# return [0, 0]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# def extract_point_answer(content):
|
| 89 |
+
# # 尝试在 <answer> 标签中查找内容,如果找不到则返回 [0, 0]
|
| 90 |
+
# tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
|
| 91 |
+
# if tool_call_match:
|
| 92 |
+
# tool_call_content = tool_call_match.group(1).strip()
|
| 93 |
+
# # 首先尝试将 tool_call_content 解析为 JSON
|
| 94 |
+
# try:
|
| 95 |
+
# tool_call_json = json.loads(tool_call_content)
|
| 96 |
+
# print(tool_call_json)
|
| 97 |
+
# arguments = tool_call_json.get("arguments", {})
|
| 98 |
+
# coordinate = arguments.get("coordinate", None)
|
| 99 |
+
# if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 100 |
+
# try:
|
| 101 |
+
# x = float(coordinate[0])
|
| 102 |
+
# y = float(coordinate[1])
|
| 103 |
+
# return [x, y]
|
| 104 |
+
# except (ValueError, TypeError):
|
| 105 |
+
# pass # 如果转换失败,继续尝试正则提取
|
| 106 |
+
# except json.JSONDecodeError:
|
| 107 |
+
# pass # 如果 JSON 解析失败,继续尝试正则提取
|
| 108 |
+
# # 回退到正则表达式提取两个数字
|
| 109 |
+
# numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
|
| 110 |
+
# if len(numbers) >= 2:
|
| 111 |
+
# x = float(numbers[-2])
|
| 112 |
+
# y = float(numbers[-1])
|
| 113 |
+
# return [x, y]
|
| 114 |
+
# return [0, 0]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def extract_point_answer(content):
|
| 119 |
+
# 尝试在 <answer> 标签中查找内容,如果找不到则返回 [0, 0]
|
| 120 |
+
tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
|
| 121 |
+
if tool_call_match:
|
| 122 |
+
tool_call_content = tool_call_match.group(1).strip()
|
| 123 |
+
# 首先尝试将 tool_call_content 解析为 JSON
|
| 124 |
+
try:
|
| 125 |
+
numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
|
| 126 |
+
if len(numbers) >= 2:
|
| 127 |
+
x = float(numbers[-2])
|
| 128 |
+
y = float(numbers[-1])
|
| 129 |
+
return [x, y]
|
| 130 |
+
except json.JSONDecodeError:
|
| 131 |
+
pass # 如果 JSON 解析失败,继续尝试正则提取
|
| 132 |
+
# 回退到正则表达式提取两个数字
|
| 133 |
+
return [0, 0]
|
| 134 |
+
|
| 135 |
+
def point_in_box(point, box):
|
| 136 |
+
x,y = point
|
| 137 |
+
if box[0] <= x < box[2] and box[1] <= y < box[3]:
|
| 138 |
+
return 1
|
| 139 |
+
else:
|
| 140 |
+
return 0
|
| 141 |
+
|
| 142 |
+
num_samples = 2000
|
| 143 |
+
num_all_sample = 0
|
| 144 |
+
num_desktop_sample = 0
|
| 145 |
+
num_mobile_sample = 0
|
| 146 |
+
num_web_sample = 0
|
| 147 |
+
num_correct_sample = 0
|
| 148 |
+
for ds in TEST_DATASETS:
|
| 149 |
+
if rank == 0:
|
| 150 |
+
print(f"Processing {ds}...")
|
| 151 |
+
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
| 152 |
+
data = json.load(open(ds_path, "r"))
|
| 153 |
+
random.seed(42)
|
| 154 |
+
random.shuffle(data)
|
| 155 |
+
data = data[:num_samples]
|
| 156 |
+
|
| 157 |
+
# Split data for distributed evaluation
|
| 158 |
+
per_rank_data = len(data) // world_size
|
| 159 |
+
start_idx = rank * per_rank_data
|
| 160 |
+
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
| 161 |
+
rank_data = data[start_idx:end_idx]
|
| 162 |
+
|
| 163 |
+
messages = []
|
| 164 |
+
|
| 165 |
+
for x in rank_data:
|
| 166 |
+
image_path = os.path.join(IMAGE_ROOT, x['img_url'])
|
| 167 |
+
width,height = x['img_size'][0],x['img_size'][1]
|
| 168 |
+
resized_height, resized_width = smart_resize(
|
| 169 |
+
height,
|
| 170 |
+
width,
|
| 171 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 172 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 173 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 174 |
+
)
|
| 175 |
+
system_content = """You are a helpful assistant.
|
| 176 |
+
#Tools
|
| 177 |
+
|
| 178 |
+
You may call one or more functions to assist with the user query.
|
| 179 |
+
|
| 180 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
| 181 |
+
<tools>
|
| 182 |
+
{"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
|
| 183 |
+
</tools>
|
| 184 |
+
|
| 185 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 186 |
+
<tool_call>
|
| 187 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
| 188 |
+
</tool_call>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
|
| 189 |
+
message = [
|
| 190 |
+
{
|
| 191 |
+
"role": "system",
|
| 192 |
+
"content": [
|
| 193 |
+
{
|
| 194 |
+
"type": "text",
|
| 195 |
+
"text": system_content
|
| 196 |
+
}
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"role": "user",
|
| 201 |
+
"content": [
|
| 202 |
+
{
|
| 203 |
+
"type": "image",
|
| 204 |
+
"image": f"file://{image_path}"
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"type": "text",
|
| 208 |
+
"text": x['task']
|
| 209 |
+
}
|
| 210 |
+
]
|
| 211 |
+
},
|
| 212 |
+
]
|
| 213 |
+
# print(message)
|
| 214 |
+
messages.append(message)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
rank_outputs = [] # List to store answers for this rank
|
| 218 |
+
all_outputs = [] # List to store all answers
|
| 219 |
+
|
| 220 |
+
# Process data
|
| 221 |
+
for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
|
| 222 |
+
batch_messages = messages[i:i + BSZ]
|
| 223 |
+
|
| 224 |
+
# Preparation for inference
|
| 225 |
+
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
| 226 |
+
|
| 227 |
+
image_inputs, video_inputs = process_vision_info(batch_messages)
|
| 228 |
+
inputs = processor(
|
| 229 |
+
text=text,
|
| 230 |
+
images=image_inputs,
|
| 231 |
+
videos=video_inputs,
|
| 232 |
+
padding=True,
|
| 233 |
+
padding_side="left",
|
| 234 |
+
return_tensors="pt",
|
| 235 |
+
)
|
| 236 |
+
inputs = inputs.to(device)
|
| 237 |
+
|
| 238 |
+
# Inference: Generation of the output
|
| 239 |
+
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
| 240 |
+
|
| 241 |
+
generated_ids_trimmed = [
|
| 242 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 243 |
+
]
|
| 244 |
+
batch_output_text = processor.batch_decode(
|
| 245 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
rank_outputs.extend(batch_output_text)
|
| 249 |
+
|
| 250 |
+
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
| 251 |
+
|
| 252 |
+
# Gather all outputs from all ranks
|
| 253 |
+
all_outputs = [None] * len(data)
|
| 254 |
+
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
| 255 |
+
|
| 256 |
+
gathered_results = [None] * world_size
|
| 257 |
+
dist.all_gather_object(gathered_results, rank_results)
|
| 258 |
+
|
| 259 |
+
assert gathered_results[-1][-1][0] == len(data) - 1
|
| 260 |
+
|
| 261 |
+
# The main process will collect all results
|
| 262 |
+
if rank == 0:
|
| 263 |
+
for results in gathered_results:
|
| 264 |
+
for idx, output in results:
|
| 265 |
+
assert idx < len(all_outputs)
|
| 266 |
+
all_outputs[idx] = output
|
| 267 |
+
assert all_outputs[-1] is not None
|
| 268 |
+
|
| 269 |
+
final_output = []
|
| 270 |
+
correct_number = 0
|
| 271 |
+
correct_number_desktop = 0
|
| 272 |
+
correct_number_mobile = 0
|
| 273 |
+
correct_number_web = 0
|
| 274 |
+
|
| 275 |
+
for input_example, model_output in zip(data, all_outputs):
|
| 276 |
+
original_output = model_output
|
| 277 |
+
ground_truth = input_example['bbox']
|
| 278 |
+
split_class = input_example['split']
|
| 279 |
+
ground_truth = [ground_truth[0] / input_example['img_size'][0], ground_truth[1] / input_example['img_size'][1], (ground_truth[0]+ground_truth[2]) / input_example['img_size'][0], (ground_truth[1]+ground_truth[3]) / input_example['img_size'][1]]
|
| 280 |
+
model_answer = extract_point_answer(original_output)
|
| 281 |
+
resized_height, resized_width = smart_resize(
|
| 282 |
+
input_example['img_size'][1],
|
| 283 |
+
input_example['img_size'][0],
|
| 284 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 285 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 286 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 287 |
+
)
|
| 288 |
+
model_answer = [model_answer[0]/resized_width,model_answer[1]/resized_height]
|
| 289 |
+
# Count correct answers
|
| 290 |
+
correct = 0
|
| 291 |
+
if model_answer is not None:
|
| 292 |
+
correct = point_in_box(model_answer, ground_truth)
|
| 293 |
+
correct_number += correct
|
| 294 |
+
num_all_sample +=1
|
| 295 |
+
num_correct_sample += correct
|
| 296 |
+
if split_class == "desktop":
|
| 297 |
+
correct_number_desktop += correct
|
| 298 |
+
num_desktop_sample += 1
|
| 299 |
+
if split_class == "mobile":
|
| 300 |
+
correct_number_mobile += correct
|
| 301 |
+
num_mobile_sample += 1
|
| 302 |
+
if split_class == "web":
|
| 303 |
+
correct_number_web += correct
|
| 304 |
+
num_web_sample += 1
|
| 305 |
+
# Create a result dictionary for this example
|
| 306 |
+
result = {
|
| 307 |
+
'image': input_example['img_url'],
|
| 308 |
+
'question': input_example['task'],
|
| 309 |
+
'resized_size': [resized_height, resized_width],
|
| 310 |
+
'ground_truth': ground_truth,
|
| 311 |
+
'model_output': original_output,
|
| 312 |
+
'extracted_answer': model_answer,
|
| 313 |
+
'correct': correct
|
| 314 |
+
}
|
| 315 |
+
final_output.append(result)
|
| 316 |
+
|
| 317 |
+
# Calculate and print accuracy
|
| 318 |
+
accuracy = correct_number / len(data) * 100
|
| 319 |
+
accuracy_desktop = correct_number_desktop / num_desktop_sample * 100
|
| 320 |
+
accuracy_mobile = correct_number_mobile / num_mobile_sample * 100
|
| 321 |
+
accuracy_web = correct_number_web / num_web_sample * 100
|
| 322 |
+
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
| 323 |
+
print(f"Accuracy of desktop: {accuracy_desktop:.2f}%")
|
| 324 |
+
print(f"Accuracy of mobile: {accuracy_mobile:.2f}%")
|
| 325 |
+
print(f"Accuracy of web: {accuracy_web:.2f}%")
|
| 326 |
+
|
| 327 |
+
# Save results to a JSON file
|
| 328 |
+
output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
|
| 329 |
+
output_dir = os.path.dirname(output_path)
|
| 330 |
+
if not os.path.exists(output_dir):
|
| 331 |
+
os.makedirs(output_dir)
|
| 332 |
+
with open(output_path, "w") as f:
|
| 333 |
+
json.dump({
|
| 334 |
+
'accuracy': accuracy,
|
| 335 |
+
'results': final_output
|
| 336 |
+
}, f, indent=2)
|
| 337 |
+
|
| 338 |
+
print(f"Results saved to {output_path}")
|
| 339 |
+
print("-"*100)
|
| 340 |
+
# 将最后的统计和打印移到rank==0的条件块内
|
| 341 |
+
if rank == 0:
|
| 342 |
+
accuracy = num_correct_sample / num_all_sample * 100
|
| 343 |
+
print(f"\nnumber of correct samples: {num_correct_sample}")
|
| 344 |
+
print(f"number of all samples: {num_all_sample}")
|
| 345 |
+
print(f"Accuracy of all datasets: {accuracy:.2f}%")
|
| 346 |
+
|
| 347 |
+
# Synchronize all processes
|
| 348 |
+
dist.barrier()
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
|
open-r1-multimodal/configs/qwen2vl_sft_config.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model arguments
|
| 2 |
+
model_name_or_path: /data/shz/ckpt/Qwen2.5-VL-3B-Instruct
|
| 3 |
+
model_revision: main
|
| 4 |
+
torch_dtype: bfloat16
|
| 5 |
+
|
| 6 |
+
# Data training arguments
|
| 7 |
+
dataset_name: /data/shz/project/vlm-r1/VLM-R1/src/open-r1-multimodal/data_script/rec.yaml
|
| 8 |
+
image_root: /data/shz/dataset/coco
|
| 9 |
+
dataset_configs:
|
| 10 |
+
- all
|
| 11 |
+
preprocessing_num_workers: 8
|
| 12 |
+
|
| 13 |
+
# SFT trainer config
|
| 14 |
+
bf16: true
|
| 15 |
+
do_eval: true
|
| 16 |
+
eval_strategy: "no"
|
| 17 |
+
gradient_accumulation_steps: 2
|
| 18 |
+
gradient_checkpointing: true
|
| 19 |
+
gradient_checkpointing_kwargs:
|
| 20 |
+
use_reentrant: false
|
| 21 |
+
hub_model_id: Qwen2.5-VL-3B-Instruct
|
| 22 |
+
hub_strategy: every_save
|
| 23 |
+
learning_rate: 2.0e-05
|
| 24 |
+
log_level: info
|
| 25 |
+
logging_steps: 5
|
| 26 |
+
logging_strategy: steps
|
| 27 |
+
lr_scheduler_type: cosine
|
| 28 |
+
packing: true
|
| 29 |
+
max_seq_length: 4096
|
| 30 |
+
max_steps: -1
|
| 31 |
+
num_train_epochs: 3
|
| 32 |
+
output_dir: /data/shz/project/vlm-r1/VLM-R1/output/Qwen2.5-VL-3B-Instruct-SFT
|
| 33 |
+
overwrite_output_dir: true
|
| 34 |
+
per_device_eval_batch_size: 1
|
| 35 |
+
per_device_train_batch_size: 4
|
| 36 |
+
push_to_hub: false
|
| 37 |
+
report_to:
|
| 38 |
+
- wandb
|
| 39 |
+
save_strategy: "no"
|
| 40 |
+
seed: 42
|
| 41 |
+
data_seed: 42
|
| 42 |
+
warmup_ratio: 0.1
|
open-r1-multimodal/data_config/gui_grounding.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets:
|
| 2 |
+
- json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/showui_desktop_no_position_high_quality_qwen25vl_4028160_attention_0.2_filtered_only_one.json
|
open-r1-multimodal/data_config/rec_internvl.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets:
|
| 2 |
+
- json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcoco_train.json
|
| 3 |
+
- json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocop_train.json
|
| 4 |
+
- json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocog_train.json
|
open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:19d1823752455bca732cc85c0f7c6327db602e8140044d946e690abc9bb3ad52
|
| 3 |
+
size 30595146
|
open-r1-multimodal/local_scripts/create_vision_cot_data.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import concurrent.futures
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 11 |
+
from functools import partial
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
from typing import Dict, List
|
| 14 |
+
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
import bytedtos
|
| 22 |
+
import seaborn as sns
|
| 23 |
+
import yaml
|
| 24 |
+
from openai import AzureOpenAI
|
| 25 |
+
from PIL import Image
|
| 26 |
+
from pillow_avif import AvifImagePlugin
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
|
| 30 |
+
|
| 31 |
+
Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
|
| 32 |
+
|
| 33 |
+
Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
|
| 34 |
+
|
| 35 |
+
Input Format:
|
| 36 |
+
Original Question: {original_question}
|
| 37 |
+
Original Answer: {original_answer}
|
| 38 |
+
|
| 39 |
+
Output Format:
|
| 40 |
+
Question: [rewrite the question if necessary]
|
| 41 |
+
Answer: [answer with reasoning steps, including calculations where applicable]
|
| 42 |
+
<think>step-by-step reasoning process</think>
|
| 43 |
+
<answer>easy to verify answer</answer>
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_image_data_url(image_input):
|
| 48 |
+
if isinstance(image_input, str) and image_input.startswith("data:"):
|
| 49 |
+
return image_input
|
| 50 |
+
|
| 51 |
+
if isinstance(image_input, str) and image_input.startswith("http"):
|
| 52 |
+
image_input = load_image(image_input)
|
| 53 |
+
|
| 54 |
+
if isinstance(image_input, str):
|
| 55 |
+
image_input = Image.open(image_input)
|
| 56 |
+
|
| 57 |
+
if not isinstance(image_input, Image.Image):
|
| 58 |
+
raise ValueError("Unsupported image input type")
|
| 59 |
+
|
| 60 |
+
if image_input.mode != "RGB":
|
| 61 |
+
image_input = image_input.convert("RGB")
|
| 62 |
+
|
| 63 |
+
buffer = BytesIO()
|
| 64 |
+
image_input.save(buffer, format="JPEG")
|
| 65 |
+
img_bytes = buffer.getvalue()
|
| 66 |
+
base64_data = base64.b64encode(img_bytes).decode("utf-8")
|
| 67 |
+
return f"data:image/jpeg;base64,{base64_data}"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
|
| 71 |
+
if image is None:
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
data_url_list = [get_image_data_url(image)]
|
| 75 |
+
client = AzureOpenAI(
|
| 76 |
+
azure_endpoint="YOUR_AZURE_ENDPOINT",
|
| 77 |
+
api_version="2023-07-01-preview",
|
| 78 |
+
api_key="YOUR_API_KEY",
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
for attempt in range(max_retries):
|
| 82 |
+
try:
|
| 83 |
+
messages = [
|
| 84 |
+
{
|
| 85 |
+
"role": "system",
|
| 86 |
+
"content": "You are an expert to analyze the image and provide useful information for users.",
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"role": "user",
|
| 90 |
+
"content": [
|
| 91 |
+
{"type": "text", "text": prompt},
|
| 92 |
+
],
|
| 93 |
+
},
|
| 94 |
+
]
|
| 95 |
+
|
| 96 |
+
for data_url in data_url_list:
|
| 97 |
+
messages[1]["content"].insert(
|
| 98 |
+
0, {"type": "image_url", "image_url": {"url": data_url}}
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
response = client.chat.completions.create(
|
| 102 |
+
model="gpt-4o-2024-08-06",
|
| 103 |
+
messages=messages,
|
| 104 |
+
temperature=0.2,
|
| 105 |
+
max_tokens=8192,
|
| 106 |
+
)
|
| 107 |
+
return response.choices[0].message.content
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
if attempt == max_retries - 1:
|
| 111 |
+
raise Exception(
|
| 112 |
+
f"Failed after {max_retries} attempts. Last error: {str(e)}"
|
| 113 |
+
)
|
| 114 |
+
delay = initial_delay * (2**attempt) + random.uniform(
|
| 115 |
+
0, 0.1 * initial_delay * (2**attempt)
|
| 116 |
+
)
|
| 117 |
+
time.sleep(delay)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def process_single_item(example):
|
| 121 |
+
try:
|
| 122 |
+
image_path = example["image_path"]
|
| 123 |
+
formatted_prompt = PROMPT_FORMAT.format(
|
| 124 |
+
original_question=example["question"], original_answer=example["answer"]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
response = gpt4o_query(image_path, formatted_prompt)
|
| 128 |
+
example["gpt4o_response"] = response
|
| 129 |
+
return example
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f"Error processing item: {str(e)}")
|
| 132 |
+
example["gpt4o_response"] = None
|
| 133 |
+
return example
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def main():
|
| 137 |
+
dataset_path = "path/to/your/dataset"
|
| 138 |
+
full_dataset = load_from_disk(dataset_path)
|
| 139 |
+
|
| 140 |
+
processed_dataset = full_dataset.map(
|
| 141 |
+
function=partial(process_single_item),
|
| 142 |
+
num_proc=256,
|
| 143 |
+
desc="Processing dataset with GPT-4o",
|
| 144 |
+
keep_in_memory=True,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
output_path = f"{dataset_path}_processed"
|
| 148 |
+
processed_dataset.save_to_disk(output_path)
|
| 149 |
+
print(f"Processed dataset saved to: {output_path}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
if __name__ == "__main__":
|
| 153 |
+
main()
|
open-r1-multimodal/local_scripts/zero3.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_multinode_launcher: standard
|
| 5 |
+
offload_optimizer_device: none
|
| 6 |
+
offload_param_device: none
|
| 7 |
+
zero3_init_flag: true
|
| 8 |
+
zero3_save_16bit_model: true
|
| 9 |
+
zero_stage: 3
|
| 10 |
+
distributed_type: DEEPSPEED
|
| 11 |
+
downcast_bf16: 'no'
|
| 12 |
+
machine_rank: 0
|
| 13 |
+
main_training_function: main
|
| 14 |
+
mixed_precision: bf16
|
| 15 |
+
num_machines: 1
|
| 16 |
+
num_processes: 8
|
| 17 |
+
rdzv_backend: static
|
| 18 |
+
same_network: true
|
| 19 |
+
tpu_env: []
|
| 20 |
+
tpu_use_cluster: false
|
| 21 |
+
tpu_use_sudo: false
|
| 22 |
+
use_cpu: false
|
open-r1-multimodal/setup.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import re
|
| 19 |
+
import shutil
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
from setuptools import find_packages, setup
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
|
| 26 |
+
stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
|
| 27 |
+
if stale_egg_info.exists():
|
| 28 |
+
print(
|
| 29 |
+
(
|
| 30 |
+
"Warning: {} exists.\n\n"
|
| 31 |
+
"If you recently updated open_r1, this is expected,\n"
|
| 32 |
+
"but it may prevent open_r1 from installing in editable mode.\n\n"
|
| 33 |
+
"This directory is automatically generated by Python's packaging tools.\n"
|
| 34 |
+
"I will remove it now.\n\n"
|
| 35 |
+
"See https://github.com/pypa/pip/issues/5466 for details.\n"
|
| 36 |
+
).format(stale_egg_info)
|
| 37 |
+
)
|
| 38 |
+
shutil.rmtree(stale_egg_info)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
|
| 42 |
+
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
|
| 43 |
+
_deps = [
|
| 44 |
+
"accelerate>=1.2.1",
|
| 45 |
+
"bitsandbytes>=0.43.0",
|
| 46 |
+
"black>=24.4.2",
|
| 47 |
+
"datasets>=3.2.0",
|
| 48 |
+
"deepspeed==0.15.4",
|
| 49 |
+
"distilabel[vllm,ray,openai]>=1.5.2",
|
| 50 |
+
"einops>=0.8.0",
|
| 51 |
+
"flake8>=6.0.0",
|
| 52 |
+
"hf_transfer>=0.1.4",
|
| 53 |
+
"huggingface-hub[cli]>=0.19.2,<1.0",
|
| 54 |
+
"isort>=5.12.0",
|
| 55 |
+
"liger_kernel==0.5.2",
|
| 56 |
+
# "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
|
| 57 |
+
"math-verify", # Used for math verification in grpo
|
| 58 |
+
"packaging>=23.0",
|
| 59 |
+
"parameterized>=0.9.0",
|
| 60 |
+
"pytest",
|
| 61 |
+
"safetensors>=0.3.3",
|
| 62 |
+
"sentencepiece>=0.1.99",
|
| 63 |
+
"torch>=2.5.1",
|
| 64 |
+
"transformers>=4.49.0",
|
| 65 |
+
"trl @ git+https://github.com/huggingface/trl.git@main",
|
| 66 |
+
"vllm==0.6.6.post1",
|
| 67 |
+
"wandb>=0.19.1",
|
| 68 |
+
"pillow",
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
# this is a lookup table with items like:
|
| 72 |
+
#
|
| 73 |
+
# tokenizers: "tokenizers==0.9.4"
|
| 74 |
+
# packaging: "packaging"
|
| 75 |
+
#
|
| 76 |
+
# some of the values are versioned whereas others aren't.
|
| 77 |
+
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def deps_list(*pkgs):
|
| 81 |
+
return [deps[pkg] for pkg in pkgs]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
extras = {}
|
| 85 |
+
extras["tests"] = deps_list("pytest", "parameterized")
|
| 86 |
+
extras["torch"] = deps_list("torch")
|
| 87 |
+
extras["quality"] = deps_list("black", "isort", "flake8")
|
| 88 |
+
# extras["eval"] = deps_list("lighteval", "math-verify")
|
| 89 |
+
extras["eval"] = deps_list("math-verify")
|
| 90 |
+
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
|
| 91 |
+
|
| 92 |
+
# core dependencies shared across the whole project - keep this to a bare minimum :)
|
| 93 |
+
install_requires = [
|
| 94 |
+
deps["accelerate"],
|
| 95 |
+
deps["bitsandbytes"],
|
| 96 |
+
deps["einops"],
|
| 97 |
+
deps["datasets"],
|
| 98 |
+
deps["deepspeed"],
|
| 99 |
+
deps["hf_transfer"],
|
| 100 |
+
deps["huggingface-hub"],
|
| 101 |
+
deps["liger_kernel"],
|
| 102 |
+
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
| 103 |
+
deps["safetensors"],
|
| 104 |
+
deps["sentencepiece"],
|
| 105 |
+
deps["transformers"],
|
| 106 |
+
deps["trl"],
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
setup(
|
| 110 |
+
name="open-r1",
|
| 111 |
+
version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
| 112 |
+
author="The Hugging Face team (past and future)",
|
| 113 |
+
author_email="lewis@huggingface.co",
|
| 114 |
+
description="Open R1",
|
| 115 |
+
# long_description=open("README.md", "r", encoding="utf-8").read(),
|
| 116 |
+
long_description_content_type="text/markdown",
|
| 117 |
+
keywords="llm inference-time compute reasoning",
|
| 118 |
+
license="Apache",
|
| 119 |
+
url="https://github.com/huggingface/open-r1",
|
| 120 |
+
package_dir={"": "src"},
|
| 121 |
+
packages=find_packages("src"),
|
| 122 |
+
zip_safe=False,
|
| 123 |
+
extras_require=extras,
|
| 124 |
+
python_requires=">=3.10.9",
|
| 125 |
+
install_requires=install_requires,
|
| 126 |
+
classifiers=[
|
| 127 |
+
"Development Status :: 3 - Alpha",
|
| 128 |
+
"Intended Audience :: Developers",
|
| 129 |
+
"Intended Audience :: Education",
|
| 130 |
+
"Intended Audience :: Science/Research",
|
| 131 |
+
"License :: OSI Approved :: Apache Software License",
|
| 132 |
+
"Operating System :: OS Independent",
|
| 133 |
+
"Programming Language :: Python :: 3",
|
| 134 |
+
"Programming Language :: Python :: 3.10",
|
| 135 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 136 |
+
],
|
| 137 |
+
)
|
open-r1-multimodal/src/open_r1.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
setup.cfg
|
| 3 |
+
setup.py
|
| 4 |
+
src/open_r1/__init__.py
|
| 5 |
+
src/open_r1/configs.py
|
| 6 |
+
src/open_r1/evaluate.py
|
| 7 |
+
src/open_r1/generate.py
|
| 8 |
+
src/open_r1/grpo.py
|
| 9 |
+
src/open_r1/grpo_gui_grounding.py
|
| 10 |
+
src/open_r1/grpo_jsonl.py
|
| 11 |
+
src/open_r1/grpo_rec.py
|
| 12 |
+
src/open_r1/sft.py
|
| 13 |
+
src/open_r1.egg-info/PKG-INFO
|
| 14 |
+
src/open_r1.egg-info/SOURCES.txt
|
| 15 |
+
src/open_r1.egg-info/dependency_links.txt
|
| 16 |
+
src/open_r1.egg-info/not-zip-safe
|
| 17 |
+
src/open_r1.egg-info/requires.txt
|
| 18 |
+
src/open_r1.egg-info/top_level.txt
|
| 19 |
+
src/open_r1/trainer/__init__.py
|
| 20 |
+
src/open_r1/trainer/grpo_config.py
|
| 21 |
+
src/open_r1/trainer/grpo_trainer.py
|
| 22 |
+
src/open_r1/trainer/qwen_grpo_trainer.py
|
| 23 |
+
src/open_r1/trainer/vllm_grpo_trainer.py
|
| 24 |
+
src/open_r1/utils/__init__.py
|
| 25 |
+
src/open_r1/utils/callbacks.py
|
| 26 |
+
src/open_r1/utils/evaluation.py
|
| 27 |
+
src/open_r1/utils/hub.py
|
| 28 |
+
src/open_r1/utils/math.py
|
| 29 |
+
src/open_r1/vlm_modules/__init__.py
|
| 30 |
+
src/open_r1/vlm_modules/internvl_module.py
|
| 31 |
+
src/open_r1/vlm_modules/qwen_module.py
|
| 32 |
+
src/open_r1/vlm_modules/vlm_module.py
|
open-r1-multimodal/src/open_r1.egg-info/not-zip-safe
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
open-r1-multimodal/src/open_r1/grpo.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# import debugpy
|
| 16 |
+
# try:
|
| 17 |
+
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
|
| 18 |
+
# debugpy.listen(("localhost", 9501))
|
| 19 |
+
# print("Waiting for debugger attach")
|
| 20 |
+
# debugpy.wait_for_client()
|
| 21 |
+
# except Exception as e:
|
| 22 |
+
# pass
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
from dataclasses import dataclass, field
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
from datasets import load_dataset, load_from_disk
|
| 31 |
+
from transformers import Qwen2VLForConditionalGeneration
|
| 32 |
+
|
| 33 |
+
from math_verify import parse, verify
|
| 34 |
+
from open_r1.trainer import VLMGRPOTrainer
|
| 35 |
+
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class GRPOScriptArguments(ScriptArguments):
|
| 40 |
+
"""
|
| 41 |
+
Script arguments for the GRPO training script.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
reward_funcs (`list[str]`):
|
| 45 |
+
List of reward functions. Possible values: 'accuracy', 'format'.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
reward_funcs: list[str] = field(
|
| 49 |
+
default_factory=lambda: ["accuracy", "format"],
|
| 50 |
+
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 51 |
+
)
|
| 52 |
+
max_pixels: Optional[int] = field(
|
| 53 |
+
default=12845056,
|
| 54 |
+
metadata={"help": "Maximum number of pixels for the image"},
|
| 55 |
+
)
|
| 56 |
+
min_pixels: Optional[int] = field(
|
| 57 |
+
default=3136,
|
| 58 |
+
metadata={"help": "Minimum number of pixels for the image"},
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def accuracy_reward(completions, solution, **kwargs):
|
| 63 |
+
"""Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
|
| 64 |
+
contents = [completion[0]["content"] for completion in completions]
|
| 65 |
+
rewards = []
|
| 66 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 67 |
+
for content, sol in zip(contents, solution):
|
| 68 |
+
reward = 0.0
|
| 69 |
+
# Try symbolic verification first
|
| 70 |
+
try:
|
| 71 |
+
answer = parse(content)
|
| 72 |
+
if float(verify(answer, parse(sol))) > 0:
|
| 73 |
+
reward = 1.0
|
| 74 |
+
except Exception:
|
| 75 |
+
pass # Continue to next verification method if this fails
|
| 76 |
+
|
| 77 |
+
# If symbolic verification failed, try string matching
|
| 78 |
+
if reward == 0.0:
|
| 79 |
+
try:
|
| 80 |
+
# Extract answer from solution if it has think/answer tags
|
| 81 |
+
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 82 |
+
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 83 |
+
|
| 84 |
+
# Extract answer from content if it has think/answer tags
|
| 85 |
+
content_match = re.search(r'<answer>(.*?)</answer>', content)
|
| 86 |
+
student_answer = content_match.group(1).strip() if content_match else content.strip()
|
| 87 |
+
|
| 88 |
+
# Compare the extracted answers
|
| 89 |
+
if student_answer == ground_truth:
|
| 90 |
+
reward = 1.0
|
| 91 |
+
except Exception:
|
| 92 |
+
pass # Keep reward as 0.0 if both methods fail
|
| 93 |
+
|
| 94 |
+
rewards.append(reward)
|
| 95 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 96 |
+
log_path = os.getenv("LOG_PATH")
|
| 97 |
+
# local_rank = int(os.getenv("LOCAL_RANK", 0))
|
| 98 |
+
with open(log_path, "a") as f:
|
| 99 |
+
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
|
| 100 |
+
f.write(f"Content: {content}\n")
|
| 101 |
+
f.write(f"Solution: {sol}\n")
|
| 102 |
+
return rewards
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def format_reward(completions, **kwargs):
|
| 106 |
+
"""Reward function that checks if the completion has a specific format."""
|
| 107 |
+
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
|
| 108 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 109 |
+
matches = [re.match(pattern, content) for content in completion_contents]
|
| 110 |
+
return [1.0 if match else 0.0 for match in matches]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
reward_funcs_registry = {
|
| 114 |
+
"accuracy": accuracy_reward,
|
| 115 |
+
"format": format_reward,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
SYSTEM_PROMPT = (
|
| 119 |
+
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
| 120 |
+
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
| 121 |
+
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
| 122 |
+
"<think> reasoning process here </think><answer> answer here </answer>"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def main(script_args, training_args, model_args):
|
| 127 |
+
# Get reward functions
|
| 128 |
+
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 129 |
+
print("reward_funcs:", reward_funcs)
|
| 130 |
+
|
| 131 |
+
# Load the dataset
|
| 132 |
+
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Format into conversation
|
| 136 |
+
def make_conversation(example):
|
| 137 |
+
return {
|
| 138 |
+
"prompt": [
|
| 139 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 140 |
+
{"role": "user", "content": example["problem"]},
|
| 141 |
+
],
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
# def make_conversation_image(example):
|
| 145 |
+
# return {
|
| 146 |
+
# "prompt": [
|
| 147 |
+
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
| 148 |
+
# {
|
| 149 |
+
# "role": "user",
|
| 150 |
+
# "content": [
|
| 151 |
+
# {"type": "image"},
|
| 152 |
+
# {"type": "text", "text": example["problem"]},
|
| 153 |
+
# ],
|
| 154 |
+
# },
|
| 155 |
+
# ],
|
| 156 |
+
# }
|
| 157 |
+
|
| 158 |
+
QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
|
| 159 |
+
|
| 160 |
+
def make_conversation_image(example):
|
| 161 |
+
return {
|
| 162 |
+
"prompt": [
|
| 163 |
+
{
|
| 164 |
+
"role": "user",
|
| 165 |
+
"content": [
|
| 166 |
+
{"type": "image"},
|
| 167 |
+
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
|
| 168 |
+
],
|
| 169 |
+
},
|
| 170 |
+
],
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if "image" in dataset[script_args.dataset_train_split].features:
|
| 175 |
+
print("has image in dataset")
|
| 176 |
+
dataset = dataset.map(make_conversation_image) # Utilize multiprocessing for faster mapping
|
| 177 |
+
# dataset = dataset.remove_columns(["original_question", "original_answer"])
|
| 178 |
+
|
| 179 |
+
else:
|
| 180 |
+
print("no image in dataset")
|
| 181 |
+
dataset = dataset.map(make_conversation)
|
| 182 |
+
dataset = dataset.remove_columns("messages")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
trainer_cls = VLMGRPOTrainer
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Initialize the GRPO trainer
|
| 189 |
+
trainer = trainer_cls(
|
| 190 |
+
model=model_args.model_name_or_path,
|
| 191 |
+
reward_funcs=reward_funcs,
|
| 192 |
+
args=training_args,
|
| 193 |
+
train_dataset=dataset[script_args.dataset_train_split],
|
| 194 |
+
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
| 195 |
+
peft_config=get_peft_config(model_args),
|
| 196 |
+
attn_implementation=model_args.attn_implementation,
|
| 197 |
+
max_pixels=script_args.max_pixels,
|
| 198 |
+
min_pixels=script_args.min_pixels,
|
| 199 |
+
torch_dtype=model_args.torch_dtype,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Train and push the model to the Hub
|
| 203 |
+
trainer.train()
|
| 204 |
+
|
| 205 |
+
# Save and push to hub
|
| 206 |
+
trainer.save_model(training_args.output_dir)
|
| 207 |
+
if training_args.push_to_hub:
|
| 208 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if __name__ == "__main__":
|
| 212 |
+
parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
|
| 213 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 214 |
+
main(script_args, training_args, model_args)
|
open-r1-multimodal/src/open_r1/sft.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Supervised fine-tuning script for decoder language models.
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
|
| 20 |
+
# One 1 node of 8 x H100s
|
| 21 |
+
accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
|
| 22 |
+
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
|
| 23 |
+
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
|
| 24 |
+
--learning_rate 2.0e-5 \
|
| 25 |
+
--num_train_epochs 1 \
|
| 26 |
+
--packing \
|
| 27 |
+
--max_seq_length 4096 \
|
| 28 |
+
--per_device_train_batch_size 4 \
|
| 29 |
+
--gradient_accumulation_steps 4 \
|
| 30 |
+
--gradient_checkpointing \
|
| 31 |
+
--bf16 \
|
| 32 |
+
--logging_steps 5 \
|
| 33 |
+
--eval_strategy steps \
|
| 34 |
+
--eval_steps 100 \
|
| 35 |
+
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import logging
|
| 39 |
+
import os
|
| 40 |
+
import sys
|
| 41 |
+
|
| 42 |
+
import datasets
|
| 43 |
+
import torch
|
| 44 |
+
from torch.utils.data import Dataset
|
| 45 |
+
import transformers
|
| 46 |
+
from datasets import load_dataset
|
| 47 |
+
from transformers import AutoTokenizer, set_seed, AutoProcessor
|
| 48 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 49 |
+
from open_r1.configs import SFTConfig
|
| 50 |
+
from open_r1.utils.callbacks import get_callbacks
|
| 51 |
+
import yaml
|
| 52 |
+
import json
|
| 53 |
+
import math
|
| 54 |
+
import random
|
| 55 |
+
from PIL import Image
|
| 56 |
+
|
| 57 |
+
from trl import (
|
| 58 |
+
ModelConfig,
|
| 59 |
+
ScriptArguments,
|
| 60 |
+
SFTTrainer,
|
| 61 |
+
TrlParser,
|
| 62 |
+
get_kbit_device_map,
|
| 63 |
+
get_peft_config,
|
| 64 |
+
get_quantization_config,
|
| 65 |
+
)
|
| 66 |
+
from dataclasses import field
|
| 67 |
+
from qwen_vl_utils import process_vision_info
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
from dataclasses import dataclass
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class SFTScriptArguments(ScriptArguments):
|
| 73 |
+
image_root: str = field(default=None, metadata={"help": "The root directory of the image."})
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
processor = None
|
| 77 |
+
|
| 78 |
+
class LazySupervisedDataset(Dataset):
|
| 79 |
+
def __init__(self, data_path: str, script_args: ScriptArguments):
|
| 80 |
+
super(LazySupervisedDataset, self).__init__()
|
| 81 |
+
self.script_args = script_args
|
| 82 |
+
self.list_data_dict = []
|
| 83 |
+
|
| 84 |
+
if data_path.endswith(".yaml"):
|
| 85 |
+
with open(data_path, "r") as file:
|
| 86 |
+
yaml_data = yaml.safe_load(file)
|
| 87 |
+
datasets = yaml_data.get("datasets")
|
| 88 |
+
# file should be in the format of:
|
| 89 |
+
# datasets:
|
| 90 |
+
# - json_path: xxxx1.json
|
| 91 |
+
# sampling_strategy: first:1000
|
| 92 |
+
# - json_path: xxxx2.json
|
| 93 |
+
# sampling_strategy: end:3000
|
| 94 |
+
# - json_path: xxxx3.json
|
| 95 |
+
# sampling_strategy: random:999
|
| 96 |
+
|
| 97 |
+
for data in datasets:
|
| 98 |
+
json_path = data.get("json_path")
|
| 99 |
+
sampling_strategy = data.get("sampling_strategy", "all")
|
| 100 |
+
sampling_number = None
|
| 101 |
+
|
| 102 |
+
if json_path.endswith(".jsonl"):
|
| 103 |
+
cur_data_dict = []
|
| 104 |
+
with open(json_path, "r") as json_file:
|
| 105 |
+
for line in json_file:
|
| 106 |
+
cur_data_dict.append(json.loads(line.strip()))
|
| 107 |
+
elif json_path.endswith(".json"):
|
| 108 |
+
with open(json_path, "r") as json_file:
|
| 109 |
+
cur_data_dict = json.load(json_file)
|
| 110 |
+
else:
|
| 111 |
+
raise ValueError(f"Unsupported file type: {json_path}")
|
| 112 |
+
|
| 113 |
+
if ":" in sampling_strategy:
|
| 114 |
+
sampling_strategy, sampling_number = sampling_strategy.split(":")
|
| 115 |
+
if "%" in sampling_number:
|
| 116 |
+
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
|
| 117 |
+
else:
|
| 118 |
+
sampling_number = int(sampling_number)
|
| 119 |
+
|
| 120 |
+
# Apply the sampling strategy
|
| 121 |
+
if sampling_strategy == "first" and sampling_number is not None:
|
| 122 |
+
cur_data_dict = cur_data_dict[:sampling_number]
|
| 123 |
+
elif sampling_strategy == "end" and sampling_number is not None:
|
| 124 |
+
cur_data_dict = cur_data_dict[-sampling_number:]
|
| 125 |
+
elif sampling_strategy == "random" and sampling_number is not None:
|
| 126 |
+
random.shuffle(cur_data_dict)
|
| 127 |
+
cur_data_dict = cur_data_dict[:sampling_number]
|
| 128 |
+
print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
|
| 129 |
+
self.list_data_dict.extend(cur_data_dict)
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError(f"Unsupported file type: {data_path}")
|
| 132 |
+
|
| 133 |
+
def __len__(self):
|
| 134 |
+
return len(self.list_data_dict)
|
| 135 |
+
|
| 136 |
+
def __getitem__(self, i):
|
| 137 |
+
# Format into conversation
|
| 138 |
+
def make_conversation_image(example):
|
| 139 |
+
image_root = self.script_args.image_root
|
| 140 |
+
# print(111, image_root)
|
| 141 |
+
# print(222, example['image'])
|
| 142 |
+
image_path = os.path.join(image_root, example['image'])
|
| 143 |
+
x1, y1, x2, y2 = example["solution"]
|
| 144 |
+
normal_caption = example["normal_caption"]
|
| 145 |
+
return [
|
| 146 |
+
{
|
| 147 |
+
"role": "user",
|
| 148 |
+
"content": [
|
| 149 |
+
{"type": "image", "image": f"file://{image_path}"},
|
| 150 |
+
{"type": "text", "text": example["problem"]},
|
| 151 |
+
],
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"role": "assistant",
|
| 155 |
+
"content": f'```json\n[\n\t{{"bbox_2d": [{int(x1)}, {int(y1)}, {int(x2)}, {int(y2)}], "label": "{normal_caption}"}}\n]\n```',
|
| 156 |
+
}
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
example = self.list_data_dict[i]
|
| 160 |
+
example["messages"] = make_conversation_image(example)
|
| 161 |
+
return example
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def collate_fn(examples):
|
| 166 |
+
texts = [
|
| 167 |
+
processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=True)
|
| 168 |
+
for example in examples
|
| 169 |
+
]
|
| 170 |
+
image_inputs = []
|
| 171 |
+
for example in examples:
|
| 172 |
+
imgs, vids = process_vision_info(example["messages"])
|
| 173 |
+
image_inputs.append(imgs)
|
| 174 |
+
batch = processor(
|
| 175 |
+
text=texts,
|
| 176 |
+
images=image_inputs,
|
| 177 |
+
return_tensors="pt",
|
| 178 |
+
padding=True,
|
| 179 |
+
)
|
| 180 |
+
labels = batch["input_ids"].clone()
|
| 181 |
+
labels[labels == processor.tokenizer.pad_token_id] = -100
|
| 182 |
+
image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
|
| 183 |
+
labels[labels == image_token_id] = -100
|
| 184 |
+
batch["labels"] = labels
|
| 185 |
+
|
| 186 |
+
return batch
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def main(script_args, training_args, model_args):
|
| 190 |
+
# Set seed for reproducibility
|
| 191 |
+
set_seed(training_args.seed)
|
| 192 |
+
|
| 193 |
+
###############
|
| 194 |
+
# Setup logging
|
| 195 |
+
###############
|
| 196 |
+
logging.basicConfig(
|
| 197 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 198 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 199 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 200 |
+
)
|
| 201 |
+
log_level = training_args.get_process_log_level()
|
| 202 |
+
logger.setLevel(log_level)
|
| 203 |
+
datasets.utils.logging.set_verbosity(log_level)
|
| 204 |
+
transformers.utils.logging.set_verbosity(log_level)
|
| 205 |
+
transformers.utils.logging.enable_default_handler()
|
| 206 |
+
transformers.utils.logging.enable_explicit_format()
|
| 207 |
+
|
| 208 |
+
# Log on each process a small summary
|
| 209 |
+
logger.warning(
|
| 210 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
| 211 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
| 212 |
+
)
|
| 213 |
+
logger.info(f"Model parameters {model_args}")
|
| 214 |
+
logger.info(f"Script parameters {script_args}")
|
| 215 |
+
logger.info(f"Data parameters {training_args}")
|
| 216 |
+
|
| 217 |
+
# Check for last checkpoint
|
| 218 |
+
last_checkpoint = None
|
| 219 |
+
if os.path.isdir(training_args.output_dir):
|
| 220 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
| 221 |
+
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
| 222 |
+
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
| 223 |
+
|
| 224 |
+
################
|
| 225 |
+
# Load datasets
|
| 226 |
+
################
|
| 227 |
+
|
| 228 |
+
dataset = LazySupervisedDataset(script_args.dataset_name, script_args)
|
| 229 |
+
|
| 230 |
+
################
|
| 231 |
+
# Load tokenizer
|
| 232 |
+
################
|
| 233 |
+
global processor
|
| 234 |
+
if "vl" in model_args.model_name_or_path.lower():
|
| 235 |
+
processor = AutoProcessor.from_pretrained(
|
| 236 |
+
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
| 237 |
+
)
|
| 238 |
+
logger.info("Using AutoProcessor for vision-language model.")
|
| 239 |
+
else:
|
| 240 |
+
processor = AutoTokenizer.from_pretrained(
|
| 241 |
+
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
|
| 242 |
+
)
|
| 243 |
+
logger.info("Using AutoTokenizer for text-only model.")
|
| 244 |
+
if hasattr(processor, "pad_token") and processor.pad_token is None:
|
| 245 |
+
processor.pad_token = processor.eos_token
|
| 246 |
+
elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
|
| 247 |
+
processor.tokenizer.pad_token = processor.tokenizer.eos_token
|
| 248 |
+
|
| 249 |
+
###################
|
| 250 |
+
# Model init kwargs
|
| 251 |
+
###################
|
| 252 |
+
logger.info("*** Initializing model kwargs ***")
|
| 253 |
+
torch_dtype = (
|
| 254 |
+
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
| 255 |
+
)
|
| 256 |
+
quantization_config = get_quantization_config(model_args)
|
| 257 |
+
model_kwargs = dict(
|
| 258 |
+
revision=model_args.model_revision,
|
| 259 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 260 |
+
attn_implementation=model_args.attn_implementation,
|
| 261 |
+
torch_dtype=torch_dtype,
|
| 262 |
+
use_cache=False if training_args.gradient_checkpointing else True,
|
| 263 |
+
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
| 264 |
+
quantization_config=quantization_config,
|
| 265 |
+
)
|
| 266 |
+
# training_args.model_init_kwargs = model_kwargs
|
| 267 |
+
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
|
| 268 |
+
if "Qwen2-VL" in model_args.model_name_or_path:
|
| 269 |
+
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
| 270 |
+
model_args.model_name_or_path, **model_kwargs
|
| 271 |
+
)
|
| 272 |
+
elif "Qwen2.5-VL" in model_args.model_name_or_path:
|
| 273 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 274 |
+
model_args.model_name_or_path, **model_kwargs
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
raise ValueError(f"Unsupported model: {model_args.model_name_or_path}")
|
| 278 |
+
############################
|
| 279 |
+
# Initialize the SFT Trainer
|
| 280 |
+
############################
|
| 281 |
+
training_args.dataset_kwargs = {
|
| 282 |
+
"skip_prepare_dataset": True,
|
| 283 |
+
}
|
| 284 |
+
training_args.remove_unused_columns = False
|
| 285 |
+
trainer = SFTTrainer(
|
| 286 |
+
model=model,
|
| 287 |
+
args=training_args,
|
| 288 |
+
train_dataset=dataset,
|
| 289 |
+
eval_dataset=None,
|
| 290 |
+
processing_class=processor.tokenizer,
|
| 291 |
+
data_collator=collate_fn,
|
| 292 |
+
peft_config=get_peft_config(model_args),
|
| 293 |
+
callbacks=get_callbacks(training_args, model_args),
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
###############
|
| 297 |
+
# Training loop
|
| 298 |
+
###############
|
| 299 |
+
logger.info("*** Train ***")
|
| 300 |
+
checkpoint = None
|
| 301 |
+
if training_args.resume_from_checkpoint is not None:
|
| 302 |
+
checkpoint = training_args.resume_from_checkpoint
|
| 303 |
+
elif last_checkpoint is not None:
|
| 304 |
+
checkpoint = last_checkpoint
|
| 305 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
| 306 |
+
metrics = train_result.metrics
|
| 307 |
+
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
|
| 308 |
+
trainer.log_metrics("train", metrics)
|
| 309 |
+
trainer.save_metrics("train", metrics)
|
| 310 |
+
trainer.save_state()
|
| 311 |
+
|
| 312 |
+
##################################
|
| 313 |
+
# Save model and create model card
|
| 314 |
+
##################################
|
| 315 |
+
logger.info("*** Save model ***")
|
| 316 |
+
trainer.save_model(training_args.output_dir)
|
| 317 |
+
logger.info(f"Model saved to {training_args.output_dir}")
|
| 318 |
+
|
| 319 |
+
# Save everything else on main process
|
| 320 |
+
kwargs = {
|
| 321 |
+
"finetuned_from": model_args.model_name_or_path,
|
| 322 |
+
"dataset": list(script_args.dataset_name),
|
| 323 |
+
"dataset_tags": list(script_args.dataset_name),
|
| 324 |
+
"tags": ["open-r1"],
|
| 325 |
+
}
|
| 326 |
+
if trainer.accelerator.is_main_process:
|
| 327 |
+
trainer.create_model_card(**kwargs)
|
| 328 |
+
# Restore k,v cache for fast inference
|
| 329 |
+
trainer.model.config.use_cache = True
|
| 330 |
+
trainer.model.config.save_pretrained(training_args.output_dir)
|
| 331 |
+
#############
|
| 332 |
+
# push to hub
|
| 333 |
+
#############
|
| 334 |
+
|
| 335 |
+
if training_args.push_to_hub:
|
| 336 |
+
logger.info("Pushing to hub...")
|
| 337 |
+
trainer.push_to_hub(**kwargs)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
if __name__ == "__main__":
|
| 343 |
+
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
|
| 344 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 345 |
+
print(script_args)
|
| 346 |
+
main(script_args, training_args, model_args)
|
open-r1-multimodal/src/open_r1/trainer/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .grpo_trainer import VLMGRPOTrainer
|
| 2 |
+
from .grpo_config import GRPOConfig
|
| 3 |
+
from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
|
| 4 |
+
from .qwen_grpo_trainer import Qwen2VLGRPOTrainer
|
| 5 |
+
__all__ = ["VLMGRPOTrainer",'Qwen2VLGRPOVLLMTrainer', "Qwen2VLGRPOTrainer"]
|
open-r1-multimodal/src/open_r1/trainer/__pycache__/vllm_grpo_trainer.cpython-310.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/utils/__pycache__/math.cpython-310.pyc.139714633805856
ADDED
|
Binary file (3.88 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/utils/__pycache__/math.cpython-310.pyc.140170314805280
ADDED
|
Binary file (3.88 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/vlm_modules/__pycache__/internvl_module.cpython-310.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/vlm_modules/__pycache__/qwen_module.cpython-310.pyc
ADDED
|
Binary file (9.36 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
|
| 2 |
+
from typing import Dict, Any, Union
|
| 3 |
+
from trl.data_utils import maybe_apply_chat_template
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from open_r1.vlm_modules.vlm_module import VLMBaseModule
|
| 7 |
+
|
| 8 |
+
class Qwen2VLModule(VLMBaseModule):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
super().__init__()
|
| 11 |
+
|
| 12 |
+
def get_vlm_key(self):
|
| 13 |
+
return "qwen"
|
| 14 |
+
|
| 15 |
+
def get_model_class(self, model_id: str, model_init_kwargs: dict):
|
| 16 |
+
if "Qwen2-VL" in model_id:
|
| 17 |
+
model_cls = Qwen2VLForConditionalGeneration
|
| 18 |
+
elif "Qwen2.5-VL" in model_id:
|
| 19 |
+
model_cls = Qwen2_5_VLForConditionalGeneration
|
| 20 |
+
else:
|
| 21 |
+
raise ValueError(f"Unsupported model: {model_id}")
|
| 22 |
+
return model_cls
|
| 23 |
+
|
| 24 |
+
def post_model_init(self, model, processing_class):
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
def get_processing_class(self):
|
| 28 |
+
return AutoProcessor
|
| 29 |
+
|
| 30 |
+
def get_vision_modules_keywords(self):
|
| 31 |
+
return ['visual']
|
| 32 |
+
|
| 33 |
+
def get_custom_multimodal_keywords(self):
|
| 34 |
+
return ['pixel_values', 'image_grid_thw']
|
| 35 |
+
|
| 36 |
+
def get_non_generate_params(self):
|
| 37 |
+
return []
|
| 38 |
+
|
| 39 |
+
def get_custom_processing_keywords(self):
|
| 40 |
+
return ['max_pixels', 'min_pixels']
|
| 41 |
+
|
| 42 |
+
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
|
| 43 |
+
prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]
|
| 44 |
+
return prompts_text
|
| 45 |
+
|
| 46 |
+
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
|
| 47 |
+
# FIXME
|
| 48 |
+
# This could only process pure-multimodal or pure-text inputs
|
| 49 |
+
if len(images) > 0:
|
| 50 |
+
prompt_inputs = processing_class(
|
| 51 |
+
text=prompts_text,
|
| 52 |
+
images=images,
|
| 53 |
+
return_tensors=return_tensors,
|
| 54 |
+
padding=padding,
|
| 55 |
+
padding_side=padding_side,
|
| 56 |
+
add_special_tokens=add_special_tokens)
|
| 57 |
+
else:
|
| 58 |
+
prompt_inputs = processing_class(
|
| 59 |
+
text=prompts_text,
|
| 60 |
+
return_tensors=return_tensors,
|
| 61 |
+
padding=padding,
|
| 62 |
+
padding_side=padding_side,
|
| 63 |
+
add_special_tokens=add_special_tokens)
|
| 64 |
+
return prompt_inputs
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_question_template(task_type: str):
|
| 68 |
+
match task_type:
|
| 69 |
+
case "rec":
|
| 70 |
+
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
|
| 71 |
+
case _:
|
| 72 |
+
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def format_reward_rec(completions, **kwargs):
|
| 76 |
+
"""Check if the Qwen model output matches a specific format."""
|
| 77 |
+
import re
|
| 78 |
+
|
| 79 |
+
# pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
|
| 80 |
+
pattern = r"<tool_call>.*?\{.*\[\d+,\s*\d+\].*\}.*?</tool_call>"
|
| 81 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 82 |
+
print(completion_contents)
|
| 83 |
+
print('-'*100)
|
| 84 |
+
# print(completion_contents)
|
| 85 |
+
# print('-'*100)
|
| 86 |
+
matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
|
| 87 |
+
return [1.0 if match else 0.0 for match in matches]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def format_reward(completions, **kwargs):
|
| 91 |
+
pattern = r"<think>.*?</think>\s*<answer>.*?\[.*?{\"bbox_2d\":\s*\[\s*\d+,\s*\d+,\s*\d+,\s*\d+\s*\]\s*,\s*\"label\":\s*\".*?\"\s*}.*?\].*?</answer>"
|
| 92 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 93 |
+
matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
|
| 94 |
+
return [1.0 if match else 0.0 for match in matches]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def point_reward(completions, solution, **kwargs):
|
| 98 |
+
"""Calculate reward based on whether the predicted point is inside the bounding box and its distance from the box center."""
|
| 99 |
+
import re
|
| 100 |
+
import json
|
| 101 |
+
import os
|
| 102 |
+
from datetime import datetime
|
| 103 |
+
import math
|
| 104 |
+
|
| 105 |
+
# 从每个 completion 中提取 content
|
| 106 |
+
contents = [completion[0]["content"] for completion in completions]
|
| 107 |
+
rewards = []
|
| 108 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 109 |
+
|
| 110 |
+
# 遍历每个 content 和对应的 solution
|
| 111 |
+
for content, sol in zip(contents, solution):
|
| 112 |
+
reward = 0.0
|
| 113 |
+
log_details = None
|
| 114 |
+
try:
|
| 115 |
+
# 使用正则表达式提取 <tool_call> 标签中的内容
|
| 116 |
+
tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
|
| 117 |
+
if tool_call_match:
|
| 118 |
+
tool_call_content = tool_call_match.group(1).strip()
|
| 119 |
+
# 解析 JSON
|
| 120 |
+
tool_call_json = json.loads(tool_call_content)
|
| 121 |
+
arguments = tool_call_json.get("arguments", {})
|
| 122 |
+
coordinate = arguments.get("coordinate", None)
|
| 123 |
+
# 检查坐标是否是一个长度为 2 的列表
|
| 124 |
+
if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 125 |
+
x, y = coordinate
|
| 126 |
+
# 确保 x 和 y 是数值类型
|
| 127 |
+
if isinstance(x, (int, float)) and isinstance(y, (int, float)):
|
| 128 |
+
# 提取边界框和图像尺寸
|
| 129 |
+
box = sol[:4] # [x_min, y_min, x_max, y_max]
|
| 130 |
+
img_width, img_height = sol[4], sol[5]
|
| 131 |
+
|
| 132 |
+
# 检查点是否在边界框内
|
| 133 |
+
if box[0] <= x <= box[2] and box[1] <= y <= box[3]:
|
| 134 |
+
base_reward = 1.0
|
| 135 |
+
else:
|
| 136 |
+
base_reward = 0.0
|
| 137 |
+
|
| 138 |
+
# 计算边界框中心
|
| 139 |
+
cx = (box[0] + box[2]) / 2
|
| 140 |
+
cy = (box[1] + box[3]) / 2
|
| 141 |
+
|
| 142 |
+
# 归一化坐标
|
| 143 |
+
nx = x / img_width
|
| 144 |
+
ny = y / img_height
|
| 145 |
+
ncx = cx / img_width
|
| 146 |
+
ncy = cy / img_height
|
| 147 |
+
|
| 148 |
+
# 计算边界框中心到图像四个角的归一化距离
|
| 149 |
+
d1 = math.sqrt((ncx - 0)**2 + (ncy - 0)**2)
|
| 150 |
+
d2 = math.sqrt((ncx - 1)**2 + (ncy - 0)**2)
|
| 151 |
+
d3 = math.sqrt((ncx - 0)**2 + (ncy - 1)**2)
|
| 152 |
+
d4 = math.sqrt((ncx - 1)**2 + (ncy - 1)**2)
|
| 153 |
+
max_d = max(d1, d2, d3, d4)
|
| 154 |
+
|
| 155 |
+
# 计算点到中心的归一化距离
|
| 156 |
+
d = math.sqrt((nx - ncx)**2 + (ny - ncy)**2)
|
| 157 |
+
d_normalized = d / max_d if max_d > 0 else 0
|
| 158 |
+
decay_term = 1 - d_normalized**2 if d <= 1 else 0
|
| 159 |
+
|
| 160 |
+
# 总奖励
|
| 161 |
+
reward = base_reward + decay_term
|
| 162 |
+
|
| 163 |
+
# 为日志记录准备数据
|
| 164 |
+
log_details = {
|
| 165 |
+
"extracted_coordinate": [x, y],
|
| 166 |
+
"base_reward": base_reward,
|
| 167 |
+
"decay_term": decay_term
|
| 168 |
+
}
|
| 169 |
+
except Exception as e:
|
| 170 |
+
# 如果解析失败或发生异常,reward 保持为 0.0
|
| 171 |
+
pass
|
| 172 |
+
|
| 173 |
+
rewards.append(reward)
|
| 174 |
+
|
| 175 |
+
# 如果启用 DEBUG_MODE,则记录详细信息
|
| 176 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 177 |
+
log_path = os.getenv("LOG_PATH")
|
| 178 |
+
with open(log_path, "a", encoding='utf-8') as f:
|
| 179 |
+
f.write(f"------------- {current_time} Point-in-box reward: {reward} -------------\n")
|
| 180 |
+
f.write(f"Content: {content}\n")
|
| 181 |
+
f.write(f"Solution box: {sol[:4]}\n")
|
| 182 |
+
f.write(f"Image size: {sol[4]}x{sol[5]}\n")
|
| 183 |
+
if log_details:
|
| 184 |
+
f.write(f"Extracted coordinate: {log_details['extracted_coordinate']}\n")
|
| 185 |
+
f.write(f"Base reward: {log_details['base_reward']}\n")
|
| 186 |
+
f.write(f"Decay term: {log_details['decay_term']}\n")
|
| 187 |
+
else:
|
| 188 |
+
f.write("Failed to extract coordinate\n")
|
| 189 |
+
|
| 190 |
+
return rewards
|
| 191 |
+
|
| 192 |
+
@staticmethod
|
| 193 |
+
def iou_reward(completions, solution, **kwargs):
|
| 194 |
+
"""Calculate IoU reward between predicted bounding box from Qwen model and ground truth bounding box."""
|
| 195 |
+
import re
|
| 196 |
+
import os
|
| 197 |
+
from datetime import datetime
|
| 198 |
+
def iou(box1, box2):
|
| 199 |
+
inter_x1 = max(box1[0], box2[0])
|
| 200 |
+
inter_y1 = max(box1[1], box2[1])
|
| 201 |
+
inter_x2 = min(box1[2]-1, box2[2]-1)
|
| 202 |
+
inter_y2 = min(box1[3]-1, box2[3]-1)
|
| 203 |
+
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
| 204 |
+
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
| 205 |
+
else:
|
| 206 |
+
inter = 0
|
| 207 |
+
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
| 208 |
+
return float(inter)/union
|
| 209 |
+
contents = [completion[0]["content"] for completion in completions]
|
| 210 |
+
rewards = []
|
| 211 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 212 |
+
answer_tag_pattern = r'<answer>(.*?)</answer>'
|
| 213 |
+
bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
|
| 214 |
+
for content, sol in zip(contents, solution):
|
| 215 |
+
reward = 0.0
|
| 216 |
+
# Try symbolic verification first
|
| 217 |
+
try:
|
| 218 |
+
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
| 219 |
+
if content_answer_match:
|
| 220 |
+
content_answer = content_answer_match.group(1).strip()
|
| 221 |
+
bbox_match = re.search(bbox_pattern, content_answer)
|
| 222 |
+
if bbox_match:
|
| 223 |
+
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
|
| 224 |
+
# if iou(bbox, sol) > 0.5:
|
| 225 |
+
# reward = 1.0
|
| 226 |
+
reward = iou(bbox, sol)
|
| 227 |
+
except Exception:
|
| 228 |
+
pass # Continue to next verification method if this fails
|
| 229 |
+
|
| 230 |
+
rewards.append(reward)
|
| 231 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 232 |
+
log_path = os.getenv("LOG_PATH")
|
| 233 |
+
# local_rank = int(os.getenv("LOCAL_RANK", 0))
|
| 234 |
+
with open(log_path, "a", encoding='utf-8') as f:
|
| 235 |
+
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
|
| 236 |
+
f.write(f"Content: {content}\n")
|
| 237 |
+
f.write(f"Solution: {sol}\n")
|
| 238 |
+
return rewards
|