XinBB committed on
Commit 22f6d8c · verified · 1 Parent(s): ee53828

Add files using upload-large-folder tool
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json filter=lfs diff=lfs merge=lfs -text
eval/test_grounding_r1_nothink_ss.py ADDED
@@ -0,0 +1,352 @@
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import json
+from tqdm import tqdm
+import re
+import os
+from pprint import pprint
+import random
+from PIL import Image
+from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+import argparse
+
+import warnings
+
+warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
+
+def setup_distributed():
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+
+    dist.init_process_group(backend="nccl")
+
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+
+    return local_rank, world_size, rank
+
+local_rank, world_size, rank = setup_distributed()
+device = f"cuda:{local_rank}"
+print(f"Process {rank} using {device}")
+
+steps = 3800
+if rank == 0:
+    print("Steps: ", steps)
+# RUN_NAME = "base"
+RUN_NAME = "Qwen2.5-VL-7B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_filtered_only_one_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e"
+# MODEL_PATH = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-7B-Instruct"
+MODEL_PATH = f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
+OUTPUT_PATH = "./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
+
+BSZ = 32
+DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot/metadata"
+
+TEST_DATASETS = ['hf_test_full']
+IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot/images"
+
+
+# TEST_DATASETS = ['lisa_test']
+# IMAGE_ROOT = "/data10/shz/dataset/lisa"
+
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving,
+# especially in multi-image and video scenarios.
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_PATH,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    device_map={"": local_rank},
+)
+# default processor
+processor = AutoProcessor.from_pretrained(MODEL_PATH, max_pixels=2007040, min_pixels=3136)
+# processor.image_processor.min_pixels = 3136
+# processor.image_processor.max_pixels = 2007040
+print(processor.image_processor.min_pixels)
+print(processor.image_processor.max_pixels)
+
+# def extract_point_answer(content):
+#     # Try to find the coordinate within <answer> tags; if it cannot be found, return [0, 0]
+#     answer_tag_pattern = r'<answer>(.*?)</answer>'
+#     content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
+#     if content_answer_match:
+#         content_answer = content_answer_match.group(1).strip()
+#         tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
+#         if tool_call_match:
+#             tool_call_content = tool_call_match.group(1).strip()
+#             # Parse the JSON
+#             tool_call_json = json.loads(tool_call_content)
+#             arguments = tool_call_json.get("arguments", {})
+#             coordinate = arguments.get("coordinate", None)
+#             if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
+#                 x, y = coordinate
+#                 extracted_coordinate = [x, y]
+#                 return extracted_coordinate
+#     return [0, 0]
+
+
+# def extract_point_answer(content):
+#     # Look for content inside <tool_call> tags; return [0, 0] if nothing is found
+#     tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
+#     if tool_call_match:
+#         tool_call_content = tool_call_match.group(1).strip()
+#         # First try to parse tool_call_content as JSON
+#         try:
+#             tool_call_json = json.loads(tool_call_content)
+#             print(tool_call_json)
+#             arguments = tool_call_json.get("arguments", {})
+#             coordinate = arguments.get("coordinate", None)
+#             if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
+#                 try:
+#                     x = float(coordinate[0])
+#                     y = float(coordinate[1])
+#                     return [x, y]
+#                 except (ValueError, TypeError):
+#                     pass  # If the conversion fails, fall back to regex extraction
+#         except json.JSONDecodeError:
+#             pass  # If JSON parsing fails, fall back to regex extraction
+#         # Fall back to extracting two numbers with a regular expression
+#         numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
+#         if len(numbers) >= 2:
+#             x = float(numbers[-2])
+#             y = float(numbers[-1])
+#             return [x, y]
+#     return [0, 0]
+
+
+def extract_point_answer(content):
+    # Look for content inside <tool_call> tags; return [0, 0] if nothing is found
+    tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
+    if tool_call_match:
+        tool_call_content = tool_call_match.group(1).strip()
+        # Extract the last two numbers (the x/y coordinate) with a regular expression
+        numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
+        if len(numbers) >= 2:
+            x = float(numbers[-2])
+            y = float(numbers[-1])
+            return [x, y]
+    return [0, 0]
+
+def point_in_box(point, box):
+    x, y = point
+    if box[0] <= x < box[2] and box[1] <= y < box[3]:
+        return 1
+    else:
+        return 0
+
+num_samples = 2000
+num_all_sample = 0
+num_desktop_sample = 0
+num_mobile_sample = 0
+num_web_sample = 0
+num_correct_sample = 0
+for ds in TEST_DATASETS:
+    if rank == 0:
+        print(f"Processing {ds}...")
+    ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
+    data = json.load(open(ds_path, "r"))
+    random.seed(42)
+    random.shuffle(data)
+    data = data[:num_samples]
+
+    # Split data for distributed evaluation
+    per_rank_data = len(data) // world_size
+    start_idx = rank * per_rank_data
+    end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
+    rank_data = data[start_idx:end_idx]
+
+    messages = []
+
+    for x in rank_data:
+        image_path = os.path.join(IMAGE_ROOT, x['img_url'])
+        width, height = x['img_size'][0], x['img_size'][1]
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
+            min_pixels=processor.image_processor.min_pixels,
+            max_pixels=processor.image_processor.max_pixels,
+        )
+        system_content = """You are a helpful assistant.
+#Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
+</tools>
+
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
+        message = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": system_content
+                    }
+                ]
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": f"file://{image_path}"
+                    },
+                    {
+                        "type": "text",
+                        "text": x['task']
+                    }
+                ]
+            },
+        ]
+        # print(message)
+        messages.append(message)
+
+    rank_outputs = []  # List to store answers for this rank
+    all_outputs = []  # List to store all answers
+
+    # Process data
+    for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
+        batch_messages = messages[i:i + BSZ]
+
+        # Preparation for inference
+        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+
+        image_inputs, video_inputs = process_vision_info(batch_messages)
+        inputs = processor(
+            text=text,
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            padding_side="left",
+            return_tensors="pt",
+        )
+        inputs = inputs.to(device)
+
+        # Inference: generation of the output
+        generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
+
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        batch_output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+        rank_outputs.extend(batch_output_text)
+
+    print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
+
+    # Gather all outputs from all ranks
+    all_outputs = [None] * len(data)
+    rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
+
+    gathered_results = [None] * world_size
+    dist.all_gather_object(gathered_results, rank_results)
+
+    assert gathered_results[-1][-1][0] == len(data) - 1
+
+    # The main process collects and scores all results
+    if rank == 0:
+        for results in gathered_results:
+            for idx, output in results:
+                assert idx < len(all_outputs)
+                all_outputs[idx] = output
+        assert all_outputs[-1] is not None
+
+        final_output = []
+        correct_number = 0
+        correct_number_desktop = 0
+        correct_number_mobile = 0
+        correct_number_web = 0
+
+        for input_example, model_output in zip(data, all_outputs):
+            original_output = model_output
+            ground_truth = input_example['bbox']
+            split_class = input_example['split']
+            # Convert the [x, y, w, h] pixel box to a normalized [x1, y1, x2, y2] box
+            ground_truth = [
+                ground_truth[0] / input_example['img_size'][0],
+                ground_truth[1] / input_example['img_size'][1],
+                (ground_truth[0] + ground_truth[2]) / input_example['img_size'][0],
+                (ground_truth[1] + ground_truth[3]) / input_example['img_size'][1],
+            ]
+            model_answer = extract_point_answer(original_output)
+            resized_height, resized_width = smart_resize(
+                input_example['img_size'][1],
+                input_example['img_size'][0],
+                factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
+                min_pixels=processor.image_processor.min_pixels,
+                max_pixels=processor.image_processor.max_pixels,
+            )
+            model_answer = [model_answer[0] / resized_width, model_answer[1] / resized_height]
+            # Count correct answers
+            correct = 0
+            if model_answer is not None:
+                correct = point_in_box(model_answer, ground_truth)
+            correct_number += correct
+            num_all_sample += 1
+            num_correct_sample += correct
+            if split_class == "desktop":
+                correct_number_desktop += correct
+                num_desktop_sample += 1
+            if split_class == "mobile":
+                correct_number_mobile += correct
+                num_mobile_sample += 1
+            if split_class == "web":
+                correct_number_web += correct
+                num_web_sample += 1
+            # Create a result dictionary for this example
+            result = {
+                'image': input_example['img_url'],
+                'question': input_example['task'],
+                'resized_size': [resized_height, resized_width],
+                'ground_truth': ground_truth,
+                'model_output': original_output,
+                'extracted_answer': model_answer,
+                'correct': correct
+            }
+            final_output.append(result)
+
+        # Calculate and print accuracy
+        accuracy = correct_number / len(data) * 100
+        accuracy_desktop = correct_number_desktop / num_desktop_sample * 100
+        accuracy_mobile = correct_number_mobile / num_mobile_sample * 100
+        accuracy_web = correct_number_web / num_web_sample * 100
+        print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
+        print(f"Accuracy of desktop: {accuracy_desktop:.2f}%")
+        print(f"Accuracy of mobile: {accuracy_mobile:.2f}%")
+        print(f"Accuracy of web: {accuracy_web:.2f}%")
+
+        # Save results to a JSON file
+        output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
+        output_dir = os.path.dirname(output_path)
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        with open(output_path, "w") as f:
+            json.dump({
+                'accuracy': accuracy,
+                'results': final_output
+            }, f, indent=2)
+
+        print(f"Results saved to {output_path}")
+        print("-" * 100)
+
+# The final statistics and printing live in a rank == 0 block
+if rank == 0:
+    accuracy = num_correct_sample / num_all_sample * 100
+    print(f"\nnumber of correct samples: {num_correct_sample}")
+    print(f"number of all samples: {num_all_sample}")
+    print(f"Accuracy of all datasets: {accuracy:.2f}%")
+
+# Synchronize all processes
+dist.barrier()
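To illustrate how the parsing helpers above behave end to end, here is a minimal sketch; the completion string and the 1288x728 resized canvas are made-up values for illustration, not outputs from a real run:

    # Hypothetical model completion in the <tool_call> format the system prompt requests.
    sample = '<tool_call>{"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [412, 230]}}</tool_call>'
    point = extract_point_answer(sample)            # [412.0, 230.0] -- the last two numbers
    # The script normalizes by the smart_resize canvas before the box test.
    norm_point = [point[0] / 1288, point[1] / 728]
    print(point_in_box(norm_point, [0.30, 0.28, 0.35, 0.35]))  # 1: the point falls inside the box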
open-r1-multimodal/configs/qwen2vl_sft_config.yaml ADDED
@@ -0,0 +1,42 @@
+# Model arguments
+model_name_or_path: /data/shz/ckpt/Qwen2.5-VL-3B-Instruct
+model_revision: main
+torch_dtype: bfloat16
+
+# Data training arguments
+dataset_name: /data/shz/project/vlm-r1/VLM-R1/src/open-r1-multimodal/data_script/rec.yaml
+image_root: /data/shz/dataset/coco
+dataset_configs:
+  - all
+preprocessing_num_workers: 8
+
+# SFT trainer config
+bf16: true
+do_eval: true
+eval_strategy: "no"
+gradient_accumulation_steps: 2
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+hub_model_id: Qwen2.5-VL-3B-Instruct
+hub_strategy: every_save
+learning_rate: 2.0e-05
+log_level: info
+logging_steps: 5
+logging_strategy: steps
+lr_scheduler_type: cosine
+packing: true
+max_seq_length: 4096
+max_steps: -1
+num_train_epochs: 3
+output_dir: /data/shz/project/vlm-r1/VLM-R1/output/Qwen2.5-VL-3B-Instruct-SFT
+overwrite_output_dir: true
+per_device_eval_batch_size: 1
+per_device_train_batch_size: 4
+push_to_hub: false
+report_to:
+  - wandb
+save_strategy: "no"
+seed: 42
+data_seed: 42
+warmup_ratio: 0.1
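A quick way to sanity-check this config before handing it to the trainer is to load it with plain PyYAML; a hedged sketch, assuming the repo-relative path of the file above:

    import yaml

    with open("open-r1-multimodal/configs/qwen2vl_sft_config.yaml") as f:
        cfg = yaml.safe_load(f)
    # Spot-check a few fields the SFT script will consume.
    assert cfg["torch_dtype"] == "bfloat16" and cfg["bf16"] is True
    print(cfg["model_name_or_path"], cfg["learning_rate"], cfg["num_train_epochs"])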
open-r1-multimodal/data_config/gui_grounding.yaml ADDED
@@ -0,0 +1,2 @@
+datasets:
+  - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/showui_desktop_no_position_high_quality_qwen25vl_4028160_attention_0.2_filtered_only_one.json
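This file follows the `datasets:` schema that LazySupervisedDataset in src/open_r1/sft.py (later in this commit) parses; a small sketch of how an entry is read, assuming the path of the file above:

    import yaml

    with open("open-r1-multimodal/data_config/gui_grounding.yaml") as f:
        for entry in yaml.safe_load(f)["datasets"]:
            # sampling_strategy defaults to "all" when omitted, mirroring LazySupervisedDataset.
            print(entry["json_path"], entry.get("sampling_strategy", "all"))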
open-r1-multimodal/data_config/rec_internvl.yaml ADDED
@@ -0,0 +1,4 @@
+datasets:
+  - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcoco_train.json
+  - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocop_train.json
+  - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocog_train.json
open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19d1823752455bca732cc85c0f7c6327db602e8140044d946e690abc9bb3ad52
+size 30595146
open-r1-multimodal/local_scripts/create_vision_cot_data.py ADDED
@@ -0,0 +1,153 @@
+import argparse
+import base64
+import concurrent.futures
+import io
+import json
+import os
+import random
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from io import BytesIO
+from typing import Dict, List
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
+from tqdm import tqdm
+
+import bytedtos
+import seaborn as sns
+import yaml
+from openai import AzureOpenAI
+from PIL import Image
+from pillow_avif import AvifImagePlugin
+# Assumption: the original file used load_image without importing it; the helper in
+# transformers.image_utils has a compatible signature for URL/path inputs.
+from transformers.image_utils import load_image
+
+
+PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
+
+Please make sure your question asks for a certain answer with a certain value; do not ask for an open-ended answer, and make sure the answer is correct and easy to verify via a simple protocol, like "2" or "A".
+
+Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
+
+Input Format:
+Original Question: {original_question}
+Original Answer: {original_answer}
+
+Output Format:
+Question: [rewrite the question if necessary]
+Answer: [answer with reasoning steps, including calculations where applicable]
+<think>step-by-step reasoning process</think>
+<answer>easy to verify answer</answer>
+"""
+
+
+def get_image_data_url(image_input):
+    if isinstance(image_input, str) and image_input.startswith("data:"):
+        return image_input
+
+    if isinstance(image_input, str) and image_input.startswith("http"):
+        image_input = load_image(image_input)
+
+    if isinstance(image_input, str):
+        image_input = Image.open(image_input)
+
+    if not isinstance(image_input, Image.Image):
+        raise ValueError("Unsupported image input type")
+
+    if image_input.mode != "RGB":
+        image_input = image_input.convert("RGB")
+
+    buffer = BytesIO()
+    image_input.save(buffer, format="JPEG")
+    img_bytes = buffer.getvalue()
+    base64_data = base64.b64encode(img_bytes).decode("utf-8")
+    return f"data:image/jpeg;base64,{base64_data}"
+
+
+def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
+    if image is None:
+        return None
+
+    data_url_list = [get_image_data_url(image)]
+    client = AzureOpenAI(
+        azure_endpoint="YOUR_AZURE_ENDPOINT",
+        api_version="2023-07-01-preview",
+        api_key="YOUR_API_KEY",
+    )
+
+    for attempt in range(max_retries):
+        try:
+            messages = [
+                {
+                    "role": "system",
+                    "content": "You are an expert to analyze the image and provide useful information for users.",
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                    ],
+                },
+            ]
+
+            for data_url in data_url_list:
+                messages[1]["content"].insert(
+                    0, {"type": "image_url", "image_url": {"url": data_url}}
+                )
+
+            response = client.chat.completions.create(
+                model="gpt-4o-2024-08-06",
+                messages=messages,
+                temperature=0.2,
+                max_tokens=8192,
+            )
+            return response.choices[0].message.content
+
+        except Exception as e:
+            if attempt == max_retries - 1:
+                raise Exception(
+                    f"Failed after {max_retries} attempts. Last error: {str(e)}"
+                )
+            delay = initial_delay * (2**attempt) + random.uniform(
+                0, 0.1 * initial_delay * (2**attempt)
+            )
+            time.sleep(delay)
+
+
+def process_single_item(example):
+    try:
+        image_path = example["image_path"]
+        formatted_prompt = PROMPT_FORMAT.format(
+            original_question=example["question"], original_answer=example["answer"]
+        )
+
+        response = gpt4o_query(image_path, formatted_prompt)
+        example["gpt4o_response"] = response
+        return example
+    except Exception as e:
+        print(f"Error processing item: {str(e)}")
+        example["gpt4o_response"] = None
+        return example
+
+
+def main():
+    dataset_path = "path/to/your/dataset"
+    full_dataset = load_from_disk(dataset_path)
+
+    processed_dataset = full_dataset.map(
+        function=partial(process_single_item),
+        num_proc=256,
+        desc="Processing dataset with GPT-4o",
+        keep_in_memory=True,
+    )
+
+    output_path = f"{dataset_path}_processed"
+    processed_dataset.save_to_disk(output_path)
+    print(f"Processed dataset saved to: {output_path}")
+
+
+if __name__ == "__main__":
+    main()
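As a self-contained check of get_image_data_url above, a round trip on a synthetic image (no network access or Azure credentials needed):

    from PIL import Image

    img = Image.new("RGB", (64, 64), color=(200, 30, 30))
    url = get_image_data_url(img)
    assert url.startswith("data:image/jpeg;base64,")
    print(url[:48] + "...")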
open-r1-multimodal/local_scripts/zero3.yaml ADDED
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_multinode_launcher: standard
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
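This accelerate config assumes a single machine with 8 local GPUs (num_machines: 1, num_processes: 8); a small, hedged pre-launch sketch that flags a mismatch with the visible device count:

    import yaml
    import torch

    with open("open-r1-multimodal/local_scripts/zero3.yaml") as f:
        cfg = yaml.safe_load(f)
    if torch.cuda.is_available() and cfg["num_processes"] > torch.cuda.device_count():
        print("Warning: num_processes exceeds the visible GPU count on this machine.")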
open-r1-multimodal/setup.py ADDED
@@ -0,0 +1,137 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
+
+
+import re
+import shutil
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+
+# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
+stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
+if stale_egg_info.exists():
+    print(
+        (
+            "Warning: {} exists.\n\n"
+            "If you recently updated open_r1, this is expected,\n"
+            "but it may prevent open_r1 from installing in editable mode.\n\n"
+            "This directory is automatically generated by Python's packaging tools.\n"
+            "I will remove it now.\n\n"
+            "See https://github.com/pypa/pip/issues/5466 for details.\n"
+        ).format(stale_egg_info)
+    )
+    shutil.rmtree(stale_egg_info)
+
+
+# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
+# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
+_deps = [
+    "accelerate>=1.2.1",
+    "bitsandbytes>=0.43.0",
+    "black>=24.4.2",
+    "datasets>=3.2.0",
+    "deepspeed==0.15.4",
+    "distilabel[vllm,ray,openai]>=1.5.2",
+    "einops>=0.8.0",
+    "flake8>=6.0.0",
+    "hf_transfer>=0.1.4",
+    "huggingface-hub[cli]>=0.19.2,<1.0",
+    "isort>=5.12.0",
+    "liger_kernel==0.5.2",
+    # "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
+    "math-verify",  # Used for math verification in grpo
+    "packaging>=23.0",
+    "parameterized>=0.9.0",
+    "pytest",
+    "safetensors>=0.3.3",
+    "sentencepiece>=0.1.99",
+    "torch>=2.5.1",
+    "transformers>=4.49.0",
+    "trl @ git+https://github.com/huggingface/trl.git@main",
+    "vllm==0.6.6.post1",
+    "wandb>=0.19.1",
+    "pillow",
+]
+
+# this is a lookup table with items like:
+#
+# tokenizers: "tokenizers==0.9.4"
+# packaging: "packaging"
+#
+# some of the values are versioned whereas others aren't.
+deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
+
+
+def deps_list(*pkgs):
+    return [deps[pkg] for pkg in pkgs]
+
+
+extras = {}
+extras["tests"] = deps_list("pytest", "parameterized")
+extras["torch"] = deps_list("torch")
+extras["quality"] = deps_list("black", "isort", "flake8")
+# extras["eval"] = deps_list("lighteval", "math-verify")
+extras["eval"] = deps_list("math-verify")
+extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
+
+# core dependencies shared across the whole project - keep this to a bare minimum :)
+install_requires = [
+    deps["accelerate"],
+    deps["bitsandbytes"],
+    deps["einops"],
+    deps["datasets"],
+    deps["deepspeed"],
+    deps["hf_transfer"],
+    deps["huggingface-hub"],
+    deps["liger_kernel"],
+    deps["packaging"],  # utilities from PyPA to e.g., compare versions
+    deps["safetensors"],
+    deps["sentencepiece"],
+    deps["transformers"],
+    deps["trl"],
+]
+
+setup(
+    name="open-r1",
+    version="0.1.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    author="The Hugging Face team (past and future)",
+    author_email="lewis@huggingface.co",
+    description="Open R1",
+    # long_description=open("README.md", "r", encoding="utf-8").read(),
+    long_description_content_type="text/markdown",
+    keywords="llm inference-time compute reasoning",
+    license="Apache",
+    url="https://github.com/huggingface/open-r1",
+    package_dir={"": "src"},
+    packages=find_packages("src"),
+    zip_safe=False,
+    extras_require=extras,
+    python_requires=">=3.10.9",
+    install_requires=install_requires,
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Education",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: OS Independent",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.10",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+)
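The deps lookup-table comprehension above is compact; a standalone demonstration of what the regex produces for a few representative pins from the list:

    import re

    _deps = ["deepspeed==0.15.4", "huggingface-hub[cli]>=0.19.2,<1.0", "pytest"]
    deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
    print(deps)
    # {'deepspeed': 'deepspeed==0.15.4', 'huggingface-hub': 'huggingface-hub[cli]>=0.19.2,<1.0', 'pytest': 'pytest'}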
open-r1-multimodal/src/open_r1.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,32 @@
+LICENSE
+setup.cfg
+setup.py
+src/open_r1/__init__.py
+src/open_r1/configs.py
+src/open_r1/evaluate.py
+src/open_r1/generate.py
+src/open_r1/grpo.py
+src/open_r1/grpo_gui_grounding.py
+src/open_r1/grpo_jsonl.py
+src/open_r1/grpo_rec.py
+src/open_r1/sft.py
+src/open_r1.egg-info/PKG-INFO
+src/open_r1.egg-info/SOURCES.txt
+src/open_r1.egg-info/dependency_links.txt
+src/open_r1.egg-info/not-zip-safe
+src/open_r1.egg-info/requires.txt
+src/open_r1.egg-info/top_level.txt
+src/open_r1/trainer/__init__.py
+src/open_r1/trainer/grpo_config.py
+src/open_r1/trainer/grpo_trainer.py
+src/open_r1/trainer/qwen_grpo_trainer.py
+src/open_r1/trainer/vllm_grpo_trainer.py
+src/open_r1/utils/__init__.py
+src/open_r1/utils/callbacks.py
+src/open_r1/utils/evaluation.py
+src/open_r1/utils/hub.py
+src/open_r1/utils/math.py
+src/open_r1/vlm_modules/__init__.py
+src/open_r1/vlm_modules/internvl_module.py
+src/open_r1/vlm_modules/qwen_module.py
+src/open_r1/vlm_modules/vlm_module.py
open-r1-multimodal/src/open_r1.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
+
open-r1-multimodal/src/open_r1/grpo.py ADDED
@@ -0,0 +1,214 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# import debugpy
+# try:
+#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
+#     debugpy.listen(("localhost", 9501))
+#     print("Waiting for debugger attach")
+#     debugpy.wait_for_client()
+# except Exception as e:
+#     pass
+
+import os
+import re
+from datetime import datetime
+from dataclasses import dataclass, field
+from typing import Optional
+
+from datasets import load_dataset, load_from_disk
+from transformers import Qwen2VLForConditionalGeneration
+
+from math_verify import parse, verify
+from open_r1.trainer import VLMGRPOTrainer
+from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
+
+
+@dataclass
+class GRPOScriptArguments(ScriptArguments):
+    """
+    Script arguments for the GRPO training script.
+
+    Args:
+        reward_funcs (`list[str]`):
+            List of reward functions. Possible values: 'accuracy', 'format'.
+    """
+
+    reward_funcs: list[str] = field(
+        default_factory=lambda: ["accuracy", "format"],
+        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
+    )
+    max_pixels: Optional[int] = field(
+        default=12845056,
+        metadata={"help": "Maximum number of pixels for the image"},
+    )
+    min_pixels: Optional[int] = field(
+        default=3136,
+        metadata={"help": "Minimum number of pixels for the image"},
+    )
+
+
+def accuracy_reward(completions, solution, **kwargs):
+    """Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
+    contents = [completion[0]["content"] for completion in completions]
+    rewards = []
+    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
+    for content, sol in zip(contents, solution):
+        reward = 0.0
+        # Try symbolic verification first
+        try:
+            answer = parse(content)
+            if float(verify(answer, parse(sol))) > 0:
+                reward = 1.0
+        except Exception:
+            pass  # Continue to next verification method if this fails
+
+        # If symbolic verification failed, try string matching
+        if reward == 0.0:
+            try:
+                # Extract answer from solution if it has think/answer tags
+                sol_match = re.search(r'<answer>(.*?)</answer>', sol)
+                ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
+
+                # Extract answer from content if it has think/answer tags
+                content_match = re.search(r'<answer>(.*?)</answer>', content)
+                student_answer = content_match.group(1).strip() if content_match else content.strip()
+
+                # Compare the extracted answers
+                if student_answer == ground_truth:
+                    reward = 1.0
+            except Exception:
+                pass  # Keep reward as 0.0 if both methods fail
+
+        rewards.append(reward)
+        if os.getenv("DEBUG_MODE") == "true":
+            log_path = os.getenv("LOG_PATH")
+            # local_rank = int(os.getenv("LOCAL_RANK", 0))
+            with open(log_path, "a") as f:
+                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
+                f.write(f"Content: {content}\n")
+                f.write(f"Solution: {sol}\n")
+    return rewards
+
+
+def format_reward(completions, **kwargs):
+    """Reward function that checks if the completion has a specific format."""
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+    completion_contents = [completion[0]["content"] for completion in completions]
+    matches = [re.match(pattern, content) for content in completion_contents]
+    return [1.0 if match else 0.0 for match in matches]
+
+
+reward_funcs_registry = {
+    "accuracy": accuracy_reward,
+    "format": format_reward,
+}
+
+SYSTEM_PROMPT = (
+    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+    "<think> reasoning process here </think><answer> answer here </answer>"
+)
+
+
+def main(script_args, training_args, model_args):
+    # Get reward functions
+    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
+    print("reward_funcs:", reward_funcs)
+
+    # Load the dataset
+    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
+
+    # Format into conversation
+    def make_conversation(example):
+        return {
+            "prompt": [
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": example["problem"]},
+            ],
+        }
+
+    # def make_conversation_image(example):
+    #     return {
+    #         "prompt": [
+    #             {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
+    #             {
+    #                 "role": "user",
+    #                 "content": [
+    #                     {"type": "image"},
+    #                     {"type": "text", "text": example["problem"]},
+    #                 ],
+    #             },
+    #         ],
+    #     }
+
+    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
+
+    def make_conversation_image(example):
+        return {
+            "prompt": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
+                    ],
+                },
+            ],
+        }
+
+    if "image" in dataset[script_args.dataset_train_split].features:
+        print("has image in dataset")
+        dataset = dataset.map(make_conversation_image)  # Utilize multiprocessing for faster mapping
+        # dataset = dataset.remove_columns(["original_question", "original_answer"])
+    else:
+        print("no image in dataset")
+        dataset = dataset.map(make_conversation)
+        dataset = dataset.remove_columns("messages")
+
+    trainer_cls = VLMGRPOTrainer
+
+    # Initialize the GRPO trainer
+    trainer = trainer_cls(
+        model=model_args.model_name_or_path,
+        reward_funcs=reward_funcs,
+        args=training_args,
+        train_dataset=dataset[script_args.dataset_train_split],
+        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
+        peft_config=get_peft_config(model_args),
+        attn_implementation=model_args.attn_implementation,
+        max_pixels=script_args.max_pixels,
+        min_pixels=script_args.min_pixels,
+        torch_dtype=model_args.torch_dtype,
+    )
+
+    # Train the model
+    trainer.train()
+
+    # Save and push to hub
+    trainer.save_model(training_args.output_dir)
+    if training_args.push_to_hub:
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
+
+
+if __name__ == "__main__":
+    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
+    script_args, training_args, model_args = parser.parse_args_and_config()
+    main(script_args, training_args, model_args)
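A toy invocation of the two reward functions above (the completion string is fabricated; accuracy_reward falls back to exact string matching when math_verify's symbolic check does not apply):

    completions = [[{"content": "<think>6 times 7</think><answer>42</answer>"}]]
    print(format_reward(completions))                              # [1.0]
    print(accuracy_reward(completions, ["<answer>42</answer>"]))   # [1.0]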
open-r1-multimodal/src/open_r1/sft.py ADDED
@@ -0,0 +1,346 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Supervised fine-tuning script for decoder language models.
+
+Usage:
+
+# On 1 node of 8 x H100s
+accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
+    --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
+    --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
+    --learning_rate 2.0e-5 \
+    --num_train_epochs 1 \
+    --packing \
+    --max_seq_length 4096 \
+    --per_device_train_batch_size 4 \
+    --gradient_accumulation_steps 4 \
+    --gradient_checkpointing \
+    --bf16 \
+    --logging_steps 5 \
+    --eval_strategy steps \
+    --eval_steps 100 \
+    --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
+"""
+
+import logging
+import os
+import sys
+
+import datasets
+import torch
+from torch.utils.data import Dataset
+import transformers
+from datasets import load_dataset
+from transformers import AutoTokenizer, set_seed, AutoProcessor
+from transformers.trainer_utils import get_last_checkpoint
+from open_r1.configs import SFTConfig
+from open_r1.utils.callbacks import get_callbacks
+import yaml
+import json
+import math
+import random
+from PIL import Image
+
+from trl import (
+    ModelConfig,
+    ScriptArguments,
+    SFTTrainer,
+    TrlParser,
+    get_kbit_device_map,
+    get_peft_config,
+    get_quantization_config,
+)
+from dataclasses import dataclass, field
+from qwen_vl_utils import process_vision_info
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SFTScriptArguments(ScriptArguments):
+    image_root: str = field(default=None, metadata={"help": "The root directory of the image."})
+
+
+processor = None
+
+
+class LazySupervisedDataset(Dataset):
+    def __init__(self, data_path: str, script_args: ScriptArguments):
+        super(LazySupervisedDataset, self).__init__()
+        self.script_args = script_args
+        self.list_data_dict = []
+
+        if data_path.endswith(".yaml"):
+            with open(data_path, "r") as file:
+                yaml_data = yaml.safe_load(file)
+                datasets = yaml_data.get("datasets")
+                # file should be in the format of:
+                # datasets:
+                #   - json_path: xxxx1.json
+                #     sampling_strategy: first:1000
+                #   - json_path: xxxx2.json
+                #     sampling_strategy: end:3000
+                #   - json_path: xxxx3.json
+                #     sampling_strategy: random:999
+
+                for data in datasets:
+                    json_path = data.get("json_path")
+                    sampling_strategy = data.get("sampling_strategy", "all")
+                    sampling_number = None
+
+                    if json_path.endswith(".jsonl"):
+                        cur_data_dict = []
+                        with open(json_path, "r") as json_file:
+                            for line in json_file:
+                                cur_data_dict.append(json.loads(line.strip()))
+                    elif json_path.endswith(".json"):
+                        with open(json_path, "r") as json_file:
+                            cur_data_dict = json.load(json_file)
+                    else:
+                        raise ValueError(f"Unsupported file type: {json_path}")
+
+                    if ":" in sampling_strategy:
+                        sampling_strategy, sampling_number = sampling_strategy.split(":")
+                        if "%" in sampling_number:
+                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
+                        else:
+                            sampling_number = int(sampling_number)
+
+                    # Apply the sampling strategy
+                    if sampling_strategy == "first" and sampling_number is not None:
+                        cur_data_dict = cur_data_dict[:sampling_number]
+                    elif sampling_strategy == "end" and sampling_number is not None:
+                        cur_data_dict = cur_data_dict[-sampling_number:]
+                    elif sampling_strategy == "random" and sampling_number is not None:
+                        random.shuffle(cur_data_dict)
+                        cur_data_dict = cur_data_dict[:sampling_number]
+                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
+                    self.list_data_dict.extend(cur_data_dict)
+        else:
+            raise ValueError(f"Unsupported file type: {data_path}")
+
+    def __len__(self):
+        return len(self.list_data_dict)
+
+    def __getitem__(self, i):
+        # Format into conversation
+        def make_conversation_image(example):
+            image_root = self.script_args.image_root
+            # print(111, image_root)
+            # print(222, example['image'])
+            image_path = os.path.join(image_root, example['image'])
+            x1, y1, x2, y2 = example["solution"]
+            normal_caption = example["normal_caption"]
+            return [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": f"file://{image_path}"},
+                        {"type": "text", "text": example["problem"]},
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": f'```json\n[\n\t{{"bbox_2d": [{int(x1)}, {int(y1)}, {int(x2)}, {int(y2)}], "label": "{normal_caption}"}}\n]\n```',
+                }
+            ]
+
+        example = self.list_data_dict[i]
+        example["messages"] = make_conversation_image(example)
+        return example
+
+
+def collate_fn(examples):
+    texts = [
+        processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=True)
+        for example in examples
+    ]
+    image_inputs = []
+    for example in examples:
+        imgs, vids = process_vision_info(example["messages"])
+        image_inputs.append(imgs)
+    batch = processor(
+        text=texts,
+        images=image_inputs,
+        return_tensors="pt",
+        padding=True,
+    )
+    labels = batch["input_ids"].clone()
+    labels[labels == processor.tokenizer.pad_token_id] = -100
+    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
+    labels[labels == image_token_id] = -100
+    batch["labels"] = labels
+
+    return batch
+
+
+def main(script_args, training_args, model_args):
+    # Set seed for reproducibility
+    set_seed(training_args.seed)
+
+    ###############
+    # Setup logging
+    ###############
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    # Log a small summary on each process
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    logger.info(f"Model parameters {model_args}")
+    logger.info(f"Script parameters {script_args}")
+    logger.info(f"Data parameters {training_args}")
+
+    # Check for last checkpoint
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir):
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
+
+    ################
+    # Load datasets
+    ################
+    dataset = LazySupervisedDataset(script_args.dataset_name, script_args)
+
+    ################
+    # Load tokenizer
+    ################
+    global processor
+    if "vl" in model_args.model_name_or_path.lower():
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+        )
+        logger.info("Using AutoProcessor for vision-language model.")
+    else:
+        processor = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
+        )
+        logger.info("Using AutoTokenizer for text-only model.")
+    if hasattr(processor, "pad_token") and processor.pad_token is None:
+        processor.pad_token = processor.eos_token
+    elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
+        processor.tokenizer.pad_token = processor.tokenizer.eos_token
+
+    ###################
+    # Model init kwargs
+    ###################
+    logger.info("*** Initializing model kwargs ***")
+    torch_dtype = (
+        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
+    )
+    quantization_config = get_quantization_config(model_args)
+    model_kwargs = dict(
+        revision=model_args.model_revision,
+        trust_remote_code=model_args.trust_remote_code,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=torch_dtype,
+        use_cache=False if training_args.gradient_checkpointing else True,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    # training_args.model_init_kwargs = model_kwargs
+    from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+    if "Qwen2-VL" in model_args.model_name_or_path:
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path, **model_kwargs
+        )
+    elif "Qwen2.5-VL" in model_args.model_name_or_path:
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path, **model_kwargs
+        )
+    else:
+        raise ValueError(f"Unsupported model: {model_args.model_name_or_path}")
+
+    ############################
+    # Initialize the SFT Trainer
+    ############################
+    training_args.dataset_kwargs = {
+        "skip_prepare_dataset": True,
+    }
+    training_args.remove_unused_columns = False
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        eval_dataset=None,
+        processing_class=processor.tokenizer,
+        data_collator=collate_fn,
+        peft_config=get_peft_config(model_args),
+        callbacks=get_callbacks(training_args, model_args),
+    )
+
+    ###############
+    # Training loop
+    ###############
+    logger.info("*** Train ***")
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        checkpoint = last_checkpoint
+    train_result = trainer.train(resume_from_checkpoint=checkpoint)
+    metrics = train_result.metrics
+    metrics["train_samples"] = len(dataset)  # LazySupervisedDataset is a single train split
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+    trainer.save_state()
+
+    ##################################
+    # Save model and create model card
+    ##################################
+    logger.info("*** Save model ***")
+    trainer.save_model(training_args.output_dir)
+    logger.info(f"Model saved to {training_args.output_dir}")
+
+    # Save everything else on main process
+    kwargs = {
+        "finetuned_from": model_args.model_name_or_path,
+        "dataset": list(script_args.dataset_name),
+        "dataset_tags": list(script_args.dataset_name),
+        "tags": ["open-r1"],
+    }
+    if trainer.accelerator.is_main_process:
+        trainer.create_model_card(**kwargs)
+        # Restore k,v cache for fast inference
+        trainer.model.config.use_cache = True
+        trainer.model.config.save_pretrained(training_args.output_dir)
+
+    #############
+    # push to hub
+    #############
+    if training_args.push_to_hub:
+        logger.info("Pushing to hub...")
+        trainer.push_to_hub(**kwargs)
+
+
+if __name__ == "__main__":
+    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
+    script_args, training_args, model_args = parser.parse_args_and_config()
+    print(script_args)
+    main(script_args, training_args, model_args)
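The sampling_strategy strings parsed by LazySupervisedDataset above ("first:N", "end:N", "random:N" or "random:P%", plus the "all" default) can be restated as a standalone helper for quick experimentation (resolve is a hypothetical name, not part of this commit):

    import math

    def resolve(strategy, n):
        # Mirrors LazySupervisedDataset: split "name:number" and expand percentages.
        if ":" not in strategy:
            return strategy, None
        name, number = strategy.split(":")
        if "%" in number:
            return name, math.ceil(int(number.split("%")[0]) * n / 100)
        return name, int(number)

    print(resolve("first:1000", 50000))  # ('first', 1000)
    print(resolve("random:10%", 50000))  # ('random', 5000)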
open-r1-multimodal/src/open_r1/trainer/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .grpo_trainer import VLMGRPOTrainer
+from .grpo_config import GRPOConfig
+from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
+from .qwen_grpo_trainer import Qwen2VLGRPOTrainer
+__all__ = ["VLMGRPOTrainer", "Qwen2VLGRPOVLLMTrainer", "Qwen2VLGRPOTrainer"]
open-r1-multimodal/src/open_r1/trainer/__pycache__/vllm_grpo_trainer.cpython-310.pyc ADDED
Binary file (18.4 kB)
open-r1-multimodal/src/open_r1/utils/__pycache__/math.cpython-310.pyc.139714633805856 ADDED
Binary file (3.88 kB)
open-r1-multimodal/src/open_r1/utils/__pycache__/math.cpython-310.pyc.140170314805280 ADDED
Binary file (3.88 kB)
open-r1-multimodal/src/open_r1/vlm_modules/__pycache__/internvl_module.cpython-310.pyc ADDED
Binary file (11.2 kB)
open-r1-multimodal/src/open_r1/vlm_modules/__pycache__/qwen_module.cpython-310.pyc ADDED
Binary file (9.36 kB)
open-r1-multimodal/src/open_r1/vlm_modules/qwen_module.py ADDED
@@ -0,0 +1,238 @@
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
+ from typing import Dict, Any, Union
+ from trl.data_utils import maybe_apply_chat_template
+ import torch
+
+ from open_r1.vlm_modules.vlm_module import VLMBaseModule
+
+ class Qwen2VLModule(VLMBaseModule):
+     def __init__(self):
+         super().__init__()
+
+     def get_vlm_key(self):
+         return "qwen"
+
+     def get_model_class(self, model_id: str, model_init_kwargs: dict):
+         if "Qwen2-VL" in model_id:
+             model_cls = Qwen2VLForConditionalGeneration
+         elif "Qwen2.5-VL" in model_id:
+             model_cls = Qwen2_5_VLForConditionalGeneration
+         else:
+             raise ValueError(f"Unsupported model: {model_id}")
+         return model_cls
+
+     def post_model_init(self, model, processing_class):
+         pass
+
+     def get_processing_class(self):
+         return AutoProcessor
+
+     def get_vision_modules_keywords(self):
+         return ['visual']
+
+     def get_custom_multimodal_keywords(self):
+         return ['pixel_values', 'image_grid_thw']
+
+     def get_non_generate_params(self):
+         return []
+
+     def get_custom_processing_keywords(self):
+         return ['max_pixels', 'min_pixels']
+
+     def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
+         prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]
+         return prompts_text
+
+     def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
+         # FIXME: this handles only pure-multimodal or pure-text batches,
+         # not batches that mix text-only and image examples.
+         if len(images) > 0:
+             prompt_inputs = processing_class(
+                 text=prompts_text,
+                 images=images,
+                 return_tensors=return_tensors,
+                 padding=padding,
+                 padding_side=padding_side,
+                 add_special_tokens=add_special_tokens)
+         else:
+             prompt_inputs = processing_class(
+                 text=prompts_text,
+                 return_tensors=return_tensors,
+                 padding=padding,
+                 padding_side=padding_side,
+                 add_special_tokens=add_special_tokens)
+         return prompt_inputs
+
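+     # Usage sketch (illustrative; variable names are hypothetical):
+     #   prompt_inputs = module.prepare_model_inputs(processor, prompts_text, images)
+     # With images present, prompt_inputs carries input_ids and attention_mask plus
+     # the multimodal keys listed above (pixel_values, image_grid_thw).
+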
+     @staticmethod
+     def get_question_template(task_type: str):
+         match task_type:
+             case "rec":
+                 return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
+             case _:
+                 return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
+
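+     # Usage sketch: the {Question} placeholder is filled via str.format, e.g.
+     #   Qwen2VLModule.get_question_template("rec").format(Question="Locate the search button.")
+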
+     @staticmethod
+     def format_reward_rec(completions, **kwargs):
+         """Check whether each completion contains a <tool_call> block with a 2-D coordinate."""
+         import re
+
+         # Earlier bbox-style pattern, kept for reference:
+         # pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
+         pattern = r"<tool_call>.*?\{.*\[\d+,\s*\d+\].*\}.*?</tool_call>"
+         completion_contents = [completion[0]["content"] for completion in completions]
+         print(completion_contents)  # debug: dump raw completions
+         print('-' * 100)
+         matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
+         return [1.0 if match else 0.0 for match in matches]
+
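+     # Example completion the pattern above accepts (illustrative; the "name" value
+     # is hypothetical, only the arguments/coordinate structure is checked downstream):
+     #   <tool_call>{"name": "computer_use", "arguments": {"coordinate": [640, 360]}}</tool_call>
+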
+     @staticmethod
+     def format_reward(completions, **kwargs):
+         """Check whether each completion is a <think>...</think><answer>[{"bbox_2d": ..., "label": ...}]</answer> block."""
+         import re
+
+         pattern = r"<think>.*?</think>\s*<answer>.*?\[.*?{\"bbox_2d\":\s*\[\s*\d+,\s*\d+,\s*\d+,\s*\d+\s*\]\s*,\s*\"label\":\s*\".*?\"\s*}.*?\].*?</answer>"
+         completion_contents = [completion[0]["content"] for completion in completions]
+         matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
+         return [1.0 if match else 0.0 for match in matches]
+
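+     # Example answer the pattern above accepts (illustrative values):
+     #   <think>reasoning...</think><answer>[{"bbox_2d": [10, 20, 110, 220], "label": "button"}]</answer>
+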
+     @staticmethod
+     def point_reward(completions, solution, **kwargs):
+         """Reward = 1 if the predicted point lies inside the ground-truth box, plus a quadratic decay term based on the normalized distance from the box center."""
+         import re
+         import json
+         import os
+         from datetime import datetime
+         import math
+
+         # Extract the content from each completion
+         contents = [completion[0]["content"] for completion in completions]
+         rewards = []
+         current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
+
+         # Iterate over each content and its corresponding solution
+         for content, sol in zip(contents, solution):
+             reward = 0.0
+             log_details = None
+             try:
+                 # Extract the body of the <tool_call> tag
+                 tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
+                 if tool_call_match:
+                     tool_call_content = tool_call_match.group(1).strip()
+                     # Parse the JSON payload
+                     tool_call_json = json.loads(tool_call_content)
+                     arguments = tool_call_json.get("arguments", {})
+                     coordinate = arguments.get("coordinate", None)
+                     # The coordinate must be a list of length 2
+                     if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
+                         x, y = coordinate
+                         # Both x and y must be numeric
+                         if isinstance(x, (int, float)) and isinstance(y, (int, float)):
+                             # Unpack the bounding box and image size
+                             box = sol[:4]  # [x_min, y_min, x_max, y_max]
+                             img_width, img_height = sol[4], sol[5]
+
+                             # Binary reward: is the point inside the bounding box?
+                             if box[0] <= x <= box[2] and box[1] <= y <= box[3]:
+                                 base_reward = 1.0
+                             else:
+                                 base_reward = 0.0
+
+                             # Bounding-box center
+                             cx = (box[0] + box[2]) / 2
+                             cy = (box[1] + box[3]) / 2
+
+                             # Normalize coordinates to [0, 1]
+                             nx = x / img_width
+                             ny = y / img_height
+                             ncx = cx / img_width
+                             ncy = cy / img_height
+
+                             # Normalized distances from the box center to the four image corners
+                             d1 = math.sqrt((ncx - 0)**2 + (ncy - 0)**2)
+                             d2 = math.sqrt((ncx - 1)**2 + (ncy - 0)**2)
+                             d3 = math.sqrt((ncx - 0)**2 + (ncy - 1)**2)
+                             d4 = math.sqrt((ncx - 1)**2 + (ncy - 1)**2)
+                             max_d = max(d1, d2, d3, d4)
+
+                             # Normalized distance from the point to the center; the
+                             # guard is on d_normalized so out-of-image predictions
+                             # cannot produce a negative decay term
+                             d = math.sqrt((nx - ncx)**2 + (ny - ncy)**2)
+                             d_normalized = d / max_d if max_d > 0 else 0
+                             decay_term = 1 - d_normalized**2 if d_normalized <= 1 else 0
+
+                             # Total reward: in-box bonus plus quadratic decay
+                             reward = base_reward + decay_term
+
+                             # Data for debug logging
+                             log_details = {
+                                 "extracted_coordinate": [x, y],
+                                 "base_reward": base_reward,
+                                 "decay_term": decay_term
+                             }
+             except Exception:
+                 # On parse failure or any other error, the reward stays 0.0
+                 pass
+
+             rewards.append(reward)
+
+             # With DEBUG_MODE enabled, log the details
+             if os.getenv("DEBUG_MODE") == "true":
+                 log_path = os.getenv("LOG_PATH")
+                 with open(log_path, "a", encoding='utf-8') as f:
+                     f.write(f"------------- {current_time} Point-in-box reward: {reward} -------------\n")
+                     f.write(f"Content: {content}\n")
+                     f.write(f"Solution box: {sol[:4]}\n")
+                     f.write(f"Image size: {sol[4]}x{sol[5]}\n")
+                     if log_details:
+                         f.write(f"Extracted coordinate: {log_details['extracted_coordinate']}\n")
+                         f.write(f"Base reward: {log_details['base_reward']}\n")
+                         f.write(f"Decay term: {log_details['decay_term']}\n")
+                     else:
+                         f.write("Failed to extract coordinate\n")
+
+         return rewards
+
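+     # Worked example (illustrative numbers): box [100, 100, 200, 200] in a 1000x1000 image.
+     # Normalized center (0.15, 0.15); farthest image corner gives max_d = 0.85 * sqrt(2) ≈ 1.202.
+     #   Prediction (150, 150): inside the box, d_normalized = 0  -> reward = 1.0 + 1.0 = 2.0
+     #   Prediction (500, 500): outside, d ≈ 0.495, d_normalized ≈ 0.412 -> reward ≈ 0.0 + 0.830
+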
+     @staticmethod
+     def iou_reward(completions, solution, **kwargs):
+         """Calculate the IoU reward between the predicted bounding box from the Qwen model and the ground-truth bounding box."""
+         import re
+         import os
+         from datetime import datetime
+
+         def iou(box1, box2):
+             # The intersection uses an inclusive-pixel (+1) convention, while the
+             # box areas below use plain width * height
+             inter_x1 = max(box1[0], box2[0])
+             inter_y1 = max(box1[1], box2[1])
+             inter_x2 = min(box1[2] - 1, box2[2] - 1)
+             inter_y2 = min(box1[3] - 1, box2[3] - 1)
+             if inter_x1 < inter_x2 and inter_y1 < inter_y2:
+                 inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
+             else:
+                 inter = 0
+             union = (box1[2] - box1[0]) * (box1[3] - box1[1]) + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter
+             return float(inter) / union if union > 0 else 0.0
+
+         contents = [completion[0]["content"] for completion in completions]
+         rewards = []
+         current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
+         answer_tag_pattern = r'<answer>(.*?)</answer>'
+         bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
+         for content, sol in zip(contents, solution):
+             reward = 0.0
+             # Try symbolic verification first
+             try:
+                 content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
+                 if content_answer_match:
+                     content_answer = content_answer_match.group(1).strip()
+                     bbox_match = re.search(bbox_pattern, content_answer)
+                     if bbox_match:
+                         bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
+                         # Dense reward: use the IoU value itself rather than thresholding at 0.5
+                         reward = iou(bbox, sol)
+             except Exception:
+                 pass  # Continue to next verification method if this fails
+
+             rewards.append(reward)
+             if os.getenv("DEBUG_MODE") == "true":
+                 log_path = os.getenv("LOG_PATH")
+                 with open(log_path, "a", encoding='utf-8') as f:
+                     f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
+                     f.write(f"Content: {content}\n")
+                     f.write(f"Solution: {sol}\n")
+         return rewards
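+
+     # Worked example (illustrative): box1 = [0, 0, 10, 10], box2 = [5, 5, 15, 15]
+     #   intersection = (9 - 5 + 1) * (9 - 5 + 1) = 25, union = 100 + 100 - 25 = 175
+     #   reward = IoU = 25 / 175 ≈ 0.143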