| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import os |
| import re |
| from datetime import datetime |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| from datasets import load_dataset, load_from_disk |
| from transformers import Qwen2VLForConditionalGeneration |
|
|
| from math_verify import parse, verify |
| from src.open_r1.trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainer |
| from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config |
|
|
|
|
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Command-line options specific to the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            Names of the reward functions to use. Possible values: 'accuracy', 'format'.
        max_pixels (`int`, *optional*):
            Maximum number of pixels for the image.
        min_pixels (`int`, *optional*):
            Minimum number of pixels for the image.
    """

    # A factory is required because a list is a mutable default.
    reward_funcs: list[str] = field(
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
        default_factory=lambda: ["accuracy", "format"],
    )
    max_pixels: Optional[int] = field(
        metadata={"help": "Maximum number of pixels for the image"},
        default=12_845_056,
    )
    min_pixels: Optional[int] = field(
        metadata={"help": "Minimum number of pixels for the image"},
        default=3_136,
    )
|
|
|
|
# Compiled once at import time. Without re.DOTALL the captured answer cannot span lines.
_ANSWER_TAG_RE = re.compile(r'<answer>(.*?)</answer>')


def _extract_answer(text):
    """Return the stripped text inside the first <answer> tag, or the stripped whole text if there is none."""
    tagged = _ANSWER_TAG_RE.search(text)
    return tagged.group(1).strip() if tagged else text.strip()


def _log_accuracy_debug(log_path, timestamp, reward, content, sol):
    """Append one debug record to `log_path`; I/O and encoding failures are printed, never raised."""
    try:
        # Explicit UTF-8 so non-ASCII model output does not depend on the locale encoding.
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(f"------------- {timestamp} Accuracy reward: {reward} -------------\n")
            f.write(f"Content: {content}\n")
            f.write(f"Solution: {sol}\n")
    except (OSError, UnicodeError) as e:
        # Debug logging is best-effort: a bad path or disk error must not abort training.
        print(e)


def accuracy_reward(completions, solution, **kwargs):
    """
    Reward function that checks if the completion is correct using either symbolic verification or exact string matching.

    Each completion is first checked with `parse`/`verify`. If that does not yield a
    positive result (or raises), the text inside the first `<answer>...</answer>` tag of
    the completion is compared, after stripping, with the one from the solution; when a
    tag is missing the whole stripped string is used instead.

    Args:
        completions: Sequence of completions; the text is read from `completion[0]["content"]`.
        solution: Ground-truth answers, paired with `completions` by position.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float per (completion, solution) pair: 1.0 if judged correct, else 0.0.

    Environment:
        When `DEBUG_MODE` equals "true" and `LOG_PATH` is set, every sample is appended to
        the file at `LOG_PATH`. If `LOG_PATH` is unset, logging is skipped.
    """
    contents = [completion[0]["content"] for completion in completions]
    # One timestamp per call so all samples of a batch share the same marker in the log.
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # The environment does not change inside the loop, so read it once.
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None

    rewards = []
    for content, sol in zip(contents, solution):
        reward = 0.0

        # Try symbolic verification first.
        try:
            answer = parse(content)
            if float(verify(answer, parse(sol))) > 0:
                reward = 1.0
        except Exception:
            # Deliberately broad: symbolic verification is best-effort on free-form model
            # output, and any failure simply falls through to string matching.
            pass

        # If symbolic verification failed, fall back to exact string matching.
        if reward == 0.0:
            try:
                if _extract_answer(content) == _extract_answer(sol):
                    reward = 1.0
            except (AttributeError, TypeError):
                # Non-string content or solution: keep the reward at 0.0.
                pass

        rewards.append(reward)
        if log_path:
            _log_accuracy_debug(log_path, current_time, reward, content, sol)
    return rewards
|
|
|
|
def format_reward(completions, **kwargs):
    """
    Reward function that checks the completion layout.

    A completion scores 1.0 only when its whole text is a `<think>...</think>` block
    followed (optionally after whitespace) by an `<answer>...</answer>` block; anything
    else scores 0.0. The text is read from `completion[0]["content"]`.
    """
    # DOTALL lets the reasoning and the answer span several lines.
    layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return [
        1.0 if layout.fullmatch(completion[0]["content"]) else 0.0
        for completion in completions
    ]
|
|
|
|
# Lookup table from reward-function name to the callable that implements it.
reward_funcs_registry = dict(
    accuracy=accuracy_reward,
    format=format_reward,
)
|
|
# System message for text-only prompts; the literal is split at sentence boundaries for readability.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. "
    "The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"
)
|
|
|
|
def main(script_args, training_args, model_args):
    """
    Prepare the prompts, build the GRPO trainer, train, and save the model.

    Args:
        script_args: Script options. Read here: `reward_funcs`, `dataset_name`,
            `dataset_config`, `dataset_train_split`, `dataset_test_split`,
            `max_pixels` and `min_pixels`.
        training_args: Training configuration, passed to the trainer as `args`.
            Read here: `use_vllm`, `eval_strategy`, `output_dir` and `push_to_hub`.
        model_args: Model options. Read here: `model_name_or_path` and
            `attn_implementation`; the object is also passed to `get_peft_config`.

    Raises:
        KeyError: If a name in `script_args.reward_funcs` is not in `reward_funcs_registry`.
    """
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format a text-only example into a system + user conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."

    # Format an image example: an image placeholder followed by the templated question
    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    if "image" in dataset[script_args.dataset_train_split].features:
        print("has image in dataset")
        dataset = dataset.map(make_conversation_image)
    else:
        print("no image in dataset")
        dataset = dataset.map(make_conversation)
        # Drop "messages" only from splits that actually have it: removing a column
        # that does not exist raises and would abort the run on such datasets.
        for split in dataset:
            if "messages" in dataset[split].column_names:
                dataset[split] = dataset[split].remove_columns("messages")

    # Pick the trainer implementation according to the vLLM setting
    trainer_cls = Qwen2VLGRPOVLLMTrainer if training_args.use_vllm else Qwen2VLGRPOTrainer
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        # The test split is only looked up when evaluation is enabled.
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    # Train the model
    trainer.train()

    # Save and, if requested, push to the hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
|
|
|
if __name__ == "__main__":
    # Parse the script, training and model arguments, then start training.
    arg_parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    parsed_script_args, parsed_training_args, parsed_model_args = arg_parser.parse_args_and_config()
    main(parsed_script_args, parsed_training_args, parsed_model_args)
|
|