| |
| """ |
| GUI-Shift GRPO Training with Custom GUI Reward Functions. |
| |
| Based on VLM-R1 framework, adapted for GUI action prediction. |
| Uses rule-based rewards: format reward + action reward. |
| |
| Paper: GUI-Shift (arXiv:2505.12493) |
| """ |
|
|
| import json |
| import re |
| import os |
| import pathlib |
| from dataclasses import dataclass, field |
| from typing import Optional, List, Dict, Any |
| from datetime import datetime |
|
|
| import torch |
| from transformers import AutoProcessor, AutoTokenizer |
| from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config |
| from open_r1.trainer import VLMGRPOTrainer |
| from open_r1.vlm_modules import get_vlm_module |
|
|
| from open_r1.qwen2_5vl_monkey_patch import ( |
| monkey_patch_qwen2_5vl_flash_attn, |
| monkey_patch_qwen2_5vl_forward, |
| monkey_patch_torch_load, |
| ) |
|
|
# Applied at import time, i.e. before main() builds the trainer.
# NOTE(review): what each patch changes lives in open_r1.qwen2_5vl_monkey_patch
# and is not visible here; the call order is kept as in the original.
monkey_patch_qwen2_5vl_flash_attn()
monkey_patch_torch_load()
|
|
|
|
def parse_gui_action(text: str) -> Optional[Dict[str, Any]]:
    """Extract a GUI action from model output text.

    The action is expected as a JSON object inside the first
    ``<answer>...</answer>`` pair. When the payload is not valid JSON, a regex
    fallback recovers ``action_type`` and the fields that belong to that type
    (``x``/``y``, ``direction``, ``app_name`` or ``text``).

    Args:
        text: Raw text that may contain an ``<answer>`` block.

    Returns:
        The action as a dict, or ``None`` when there is no ``<answer>`` block,
        the payload is valid JSON but not an object, or no ``action_type`` can
        be recovered from a malformed payload.
    """
    match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if not match:
        return None

    content = match.group(1).strip()

    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        pass  # fall through to the regex-based recovery below
    else:
        # Valid JSON that is a list/string/number/null is not an action; callers
        # rely on getting a dict (they call ``.get`` on the result).
        return parsed if isinstance(parsed, dict) else None

    action_type_match = re.search(r'"action_type"\s*:\s*"([^"]+)"', content)
    if not action_type_match:
        return None

    action_type = action_type_match.group(1)
    recovered: Dict[str, Any] = {"action_type": action_type}

    if action_type in ("click", "long_press"):
        # Accept negative and decimal coordinates so that e.g. 0.25 is not
        # silently truncated to 0; plain integers still come back as int.
        number = r'(-?\d+(?:\.\d+)?)'
        x_match = re.search(r'"x"\s*:\s*' + number, content)
        y_match = re.search(r'"y"\s*:\s*' + number, content)
        if x_match and y_match:
            for key, found in (("x", x_match), ("y", y_match)):
                raw = found.group(1)
                recovered[key] = float(raw) if "." in raw else int(raw)
    else:
        # Action types that carry exactly one string-valued field.
        string_fields = {
            "scroll": "direction",
            "open_app": "app_name",
            "input_text": "text",
        }
        field_name = string_fields.get(action_type)
        if field_name is not None:
            value_match = re.search(
                r'"' + field_name + r'"\s*:\s*"([^"]+)"', content
            )
            if value_match:
                recovered[field_name] = value_match.group(1)

    return recovered
|
|
|
|
def gui_format_reward(completions: List[Dict], **kwargs) -> List[float]:
    """Score each completion 1.0 if it contains an <answer>...</answer> pair, else 0.0.

    Only the ``"content"`` of the first message of every completion is
    inspected; extra keyword arguments are accepted and ignored.
    """
    answer_block = re.compile(r'<answer>.*?</answer>', re.DOTALL)
    return [
        1.0 if answer_block.search(sample[0]["content"]) else 0.0
        for sample in completions
    ]
|
|
|
|
def gui_action_reward(completions: List[Dict], solution: List[str], **kwargs) -> List[float]:
    """
    Reward for correct GUI action prediction.

    Action space:
    - click / long_press: if the solution carries numeric "x"/"y", correct when
      the predicted point is within 20 px of it on each axis; otherwise correct
      when the predicted point lies inside the ground-truth bbox
    - scroll: correct if direction matches
    - open_app / input_text: correct if string matches exactly
    - navigate_back / navigate_home / wait: correct if action type matches

    Args:
        completions: One entry per sample; the text is read from
            ``completion[0]["content"]``.
        solution: Ground-truth answer text per sample, parsed with
            ``parse_gui_action``.
        **kwargs: Extra columns. ``ground_truth_bbox`` may be a single
            ``[x1, y1, x2, y2]`` box shared by all samples or a list holding
            one such box per sample.

    Returns:
        One reward per (completion, solution) pair: 1.0 if correct, else 0.0.
        Unparseable actions, mismatched action types, and click predictions
        without numeric coordinates all score 0.0.
    """
    tolerance = 20  # maximum per-axis pixel distance from the ground-truth point
    string_field = {"scroll": "direction", "open_app": "app_name", "input_text": "text"}
    batch_bbox = kwargs.get("ground_truth_bbox")

    def is_number(value) -> bool:
        # bool is an int subclass but is never a meaningful coordinate.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def is_bbox(candidate) -> bool:
        return (
            isinstance(candidate, (list, tuple))
            and len(candidate) >= 4
            and all(is_number(v) for v in candidate[:4])
        )

    def resolve_bbox(gt_action, index):
        """Pick the bbox for one sample: solution first, then the kwargs column."""
        own = gt_action.get("bbox")
        if is_bbox(own):
            return own
        if is_bbox(batch_bbox):
            # A flat list of four numbers: one box shared by the whole batch.
            return batch_bbox
        # Otherwise treat the column as one box per sample, aligned with
        # ``completions``; indexing the column itself as a box was the bug.
        if isinstance(batch_bbox, (list, tuple)) and index < len(batch_bbox):
            per_sample = batch_bbox[index]
            if is_bbox(per_sample):
                return per_sample
        return None

    def click_hit(pred_action, gt_action, index) -> bool:
        x, y = pred_action.get("x"), pred_action.get("y")
        # Do not default missing coordinates to 0: that would reward a
        # coordinate-free prediction whenever the target is near the origin.
        if not (is_number(x) and is_number(y)):
            return False
        gt_x, gt_y = gt_action.get("x"), gt_action.get("y")
        if is_number(gt_x) and is_number(gt_y):
            return abs(x - gt_x) <= tolerance and abs(y - gt_y) <= tolerance
        bbox = resolve_bbox(gt_action, index)
        if bbox is None:
            return False
        return bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]

    rewards = []
    for index, (completion, sol_text) in enumerate(zip(completions, solution)):
        pred_action = parse_gui_action(completion[0]["content"])
        gt_action = parse_gui_action(sol_text)

        # Guard against parsers that hand back non-dict JSON values.
        if not (
            isinstance(pred_action, dict) and pred_action
            and isinstance(gt_action, dict) and gt_action
        ):
            rewards.append(0.0)
            continue

        pred_type = pred_action.get("action_type", "")
        if pred_type != gt_action.get("action_type", ""):
            rewards.append(0.0)
            continue

        if pred_type in ("click", "long_press"):
            correct = click_hit(pred_action, gt_action, index)
        elif pred_type in ("scroll", "open_app", "input_text"):
            key = string_field[pred_type]
            correct = pred_action.get(key, "") == gt_action.get(key, "")
        else:
            # Parameter-free actions only need the (already matched) type.
            correct = pred_type in ("navigate_back", "navigate_home", "wait")

        rewards.append(1.0 if correct else 0.0)

    return rewards
|
|
|
|
def gui_combined_reward(completions: List[Dict], solution: List[str], **kwargs) -> List[float]:
    """Return, per sample, the sum of the format reward and the action reward."""
    per_format = gui_format_reward(completions, **kwargs)
    per_action = gui_action_reward(completions, solution, **kwargs)

    totals = []
    for format_score, action_score in zip(per_format, per_action):
        totals.append(format_score + action_score)
    return totals
|
|
|
|
# Maps the names accepted by the ``reward_funcs`` script argument to the reward
# callables defined above; main() looks every requested name up in this dict.
reward_funcs_registry = {
    "format": gui_format_reward,
    "accuracy": gui_action_reward,
    "combined": gui_combined_reward,
}
|
|
|
|
@dataclass
class GUIGRPOScriptArguments(ScriptArguments):
    """Extended script arguments for GUI-Shift training.

    Adds dataset locations, reward selection and image-size limits on top of
    the inherited ``ScriptArguments`` fields.
    """

    # ':'-separated JSONL files; split and read by load_gui_dataset.
    # NOTE(review): annotated ``str`` but defaults to None, and
    # load_gui_dataset calls ``.split`` on it, so it is effectively required.
    data_file_paths: str = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    # ':'-separated image roots, paired positionally with data_file_paths.
    # NOTE(review): same ``str``/None mismatch as data_file_paths.
    image_folders: str = field(
        default=None,
        metadata={"help": "Paths to image folders, separated by ':'"},
    )
    # Passed as ``test_size`` to train_test_split in main(); a value of 0 or
    # less means no validation split is made.
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split"},
    )
    # Each name must be a key of reward_funcs_registry.
    reward_funcs: List[str] = field(
        default_factory=lambda: ["format", "accuracy"],
        metadata={"help": "List of reward functions to use"},
    )
    # Forwarded unchanged to VLMGRPOTrainer in main().
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    # Forwarded unchanged to VLMGRPOTrainer in main().
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    # NOTE(review): not read anywhere in the visible part of this file.
    task_type: Optional[str] = field(
        default="gui",
        metadata={"help": "Task type for GUI action prediction"},
    )
|
|
|
|
@dataclass
class GUIGRPOModelConfig(ModelConfig):
    """``ModelConfig`` extended with a switch for freezing the vision modules."""

    # Forwarded to VLMGRPOTrainer as ``freeze_vision_modules`` in main().
    freeze_vision_modules: bool = field(
        default=True,
        metadata={"help": "Freeze vision encoder and projector during training"},
    )
|
|
|
|
# Instruction text describing the task and the <answer>{...}</answer> output
# format that the reward functions look for.
# NOTE(review): this constant is not referenced in the visible part of the
# file; make_conversation() builds the prompt from example["problem"] only.
# Confirm whether it is meant to be injected into the conversation.
SYSTEM_PROMPT = (
    "You are a GUI automation assistant. Given two screenshots showing a GUI before and after an action, "
    "predict the action that caused the transition. "
    "Output your answer in the following format: <answer>{\"action_type\": ..., ...}</answer>"
)
|
|
|
|
def _normalize_gui_record(item: Dict[str, Any], image_folder: str) -> Dict[str, Any]:
    """Convert one raw JSONL record (in place) into the flat training schema."""
    if "image" in item:
        image = item.pop("image")
        # A single file name and a list of file names are both accepted; any
        # other value is dropped without creating ``image_path``.
        if isinstance(image, str):
            item["image_path"] = [os.path.join(image_folder, image)]
        elif isinstance(image, list):
            item["image_path"] = [os.path.join(image_folder, name) for name in image]

    conversations = item.pop("conversations")

    # Removing every "<image>" token also covers "<image><image>".
    item["problem"] = conversations[0]["value"].replace("<image>", "")

    answer = conversations[1]["value"]
    if isinstance(answer, str):
        # Tags are stripped here; main() re-wraps the solution in <answer> tags.
        item["solution"] = answer.replace("<answer>", "").replace("</answer>", "").strip()
    else:
        item["solution"] = str(answer)

    item.setdefault("ground_truth_bbox", [0, 0, 0, 0])
    item.setdefault("k", 1)
    return item


def load_gui_dataset(data_file_paths: str, image_folders: str):
    """Load GUI transition dataset from JSONL files.

    Args:
        data_file_paths: ':'-separated paths to JSONL files. Each non-blank
            line is a JSON object with a ``conversations`` list whose first
            entry holds the question and second entry holds the answer, plus
            an optional ``image`` (file name or list of file names).
        image_folders: ':'-separated image root folders, one per data file, in
            the same order.

    Returns:
        A ``datasets.Dataset`` whose rows contain ``problem``, ``solution``,
        ``ground_truth_bbox``, ``k`` and, when images are present,
        ``image_path``, along with any other keys found in the record.

    Raises:
        ValueError: If either argument is missing/empty or the number of data
            files differs from the number of image folders.
    """
    if not data_file_paths or not image_folders:
        raise ValueError("Both data_file_paths and image_folders must be provided")

    data_files = data_file_paths.split(":")
    image_folders_list = image_folders.split(":")

    if len(data_files) != len(image_folders_list):
        raise ValueError("Number of data files must match number of image folders")

    # Imported after the cheap argument checks so bad arguments fail fast.
    from datasets import Dataset

    all_data = []
    for data_file, image_folder in zip(data_files, image_folders_list):
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                all_data.append(_normalize_gui_record(json.loads(line), image_folder))

    return Dataset.from_list(all_data)
|
|
|
|
def main(script_args, training_args, model_args):
    """Run GUI-Shift GRPO training end to end.

    Resolves the VLM module and reward functions, loads the dataset and
    converts it to conversation format, optionally holds out a validation
    split, then trains (resuming when ``checkpoint-*`` directories exist in
    the output directory), saves the model and optionally pushes it to the hub.
    """
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print(f"Using VLM module: {vlm_module_cls.__name__}")

    reward_funcs = [reward_funcs_registry[name] for name in script_args.reward_funcs]
    print(f"Reward functions: {script_args.reward_funcs}")

    dataset = load_gui_dataset(script_args.data_file_paths, script_args.image_folders)
    print(f"Loaded dataset with {len(dataset)} samples")

    def make_conversation(example):
        """Convert dataset example to conversation format for GRPO."""
        image_paths = example["image_path"] if "image_path" in example else None
        # One image placeholder per path, placed before the text part.
        image_slots = (
            [{"type": "image", "text": None} for _ in image_paths]
            if image_paths
            else []
        )
        user_turn = {
            "role": "user",
            "content": [
                *image_slots,
                {"type": "text", "text": example["problem"]}
            ]
        }
        return {
            "image_path": example.get("image_path", []),
            "problem": example["problem"],
            "solution": f"<answer> {example['solution']} </answer>",
            "ground_truth_bbox": example.get("ground_truth_bbox", [0, 0, 0, 0]),
            "k": example.get("k", 1),
            "prompt": [user_turn],
        }

    dataset = dataset.map(make_conversation, num_proc=8)

    # Hold out a validation split only when a positive ratio was requested.
    train_split, val_split = dataset, None
    if script_args.val_split_ratio > 0:
        parts = dataset.train_test_split(test_size=script_args.val_split_ratio)
        train_split, val_split = parts["train"], parts["test"]

    # The validation split is only handed over when evaluation is enabled.
    eval_split = val_split if training_args.eval_strategy != "no" else None

    trainer = VLMGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=train_split,
        eval_dataset=eval_split,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    # Resume if the output directory already holds at least one checkpoint.
    has_checkpoint = any(pathlib.Path(training_args.output_dir).glob("checkpoint-*"))
    if has_checkpoint:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()
|
|
|
|
if __name__ == "__main__":
    # Parse the three argument groups from the command line and/or config file.
    parser = TrlParser((GUIGRPOScriptArguments, GRPOConfig, GUIGRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # The forward patch is applied only when the DeepSpeed setting mentions
    # "zero3", and it must be in place before main() builds the trainer.
    deepspeed_setting = training_args.deepspeed
    if deepspeed_setting and "zero3" in deepspeed_setting:
        print("Zero3 detected, applying Qwen2.5-VL forward monkey patch")
        monkey_patch_qwen2_5vl_forward()

    main(script_args, training_args, model_args)
|
|