#!/usr/bin/env python3
"""
GUI-Shift GRPO Training with Custom GUI Reward Functions.

Based on VLM-R1 framework, adapted for GUI action prediction.
Uses rule-based rewards: format reward + action reward.

Paper: GUI-Shift (arXiv:2505.12493)
"""

import json
import re
import os
import pathlib
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from datetime import datetime

import torch
from transformers import AutoProcessor, AutoTokenizer
from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from open_r1.trainer import VLMGRPOTrainer
from open_r1.vlm_modules import get_vlm_module
from open_r1.qwen2_5vl_monkey_patch import (
    monkey_patch_qwen2_5vl_flash_attn,
    monkey_patch_qwen2_5vl_forward,
    monkey_patch_torch_load,
)

# Applied at import time, i.e. before main() builds the trainer. The forward
# patch is applied separately in the __main__ block, only under ZeRO-3.
monkey_patch_qwen2_5vl_flash_attn()
monkey_patch_torch_load()


# Both model completions and ground-truth solutions carry the action JSON inside
# <answer>...</answer>; make_conversation() adds the same wrapper to solutions.
# DOTALL lets the answer span several lines; the non-greedy body stops at the
# first closing tag.
_ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

# Image placeholder tags (e.g. "<image>") in the raw user prompt text. Images are
# supplied as separate "image" content entries, so the tags are removed from text.
_IMAGE_PLACEHOLDER_PATTERN = re.compile(r"</?image[^>]*>")

# Action types whose single string parameter must match the ground truth exactly.
_STRING_PARAM_FOR_ACTION = {
    "scroll": "direction",
    "open_app": "app_name",
    "input_text": "text",
}

# Action types that are correct as soon as the action type matches.
_PARAMETERLESS_ACTIONS = ("navigate_back", "navigate_home", "wait")

# Max per-axis distance (pixels) between predicted and ground-truth point when the
# ground truth is stored as x/y coordinates instead of a bbox.
CLICK_TOLERANCE = 20


def _parse_action_fallback(content: str) -> Optional[Dict[str, Any]]:
    """Recover an action from answer text that is not valid JSON, via regexes.

    Extracts ``action_type`` plus the one parameter relevant to that type
    (``x``/``y`` for click/long_press, or the string parameter listed in
    ``_STRING_PARAM_FOR_ACTION``). Returns None if no ``action_type`` is found.
    """
    type_match = re.search(r'"action_type"\s*:\s*"([^"]+)"', content)
    if not type_match:
        return None
    action_type = type_match.group(1)
    action: Dict[str, Any] = {"action_type": action_type}

    if action_type in ("click", "long_press"):
        x_match = re.search(r'"x"\s*:\s*(\d+)', content)
        y_match = re.search(r'"y"\s*:\s*(\d+)', content)
        if x_match and y_match:
            action["x"] = int(x_match.group(1))
            action["y"] = int(y_match.group(1))
        return action

    param = _STRING_PARAM_FOR_ACTION.get(action_type)
    if param is not None:
        value_match = re.search(rf'"{re.escape(param)}"\s*:\s*"([^"]+)"', content)
        if value_match:
            action[param] = value_match.group(1)
    return action


def parse_gui_action(text: str) -> Optional[Dict[str, Any]]:
    """Extract a GUI action dict from the ``<answer>...</answer>`` span of ``text``.

    The answer body is parsed as JSON first; if that fails, a regex fallback
    recovers the action type and its main parameter.

    Args:
        text: A model completion or a ground-truth solution string.

    Returns:
        The action as a dict, or None when there is no answer span, the answer
        is JSON but not an object, or no ``action_type`` can be recovered.
    """
    if not isinstance(text, str):
        return None
    match = _ANSWER_PATTERN.search(text)
    if not match:
        return None
    content = match.group(1).strip()

    try:
        action = json.loads(content)
    except json.JSONDecodeError:
        return _parse_action_fallback(content)
    # Valid JSON that is not an object (a list, string, number...) has no action
    # fields, and the reward code reads fields with .get(); reject it here.
    return action if isinstance(action, dict) else None


def gui_format_reward(completions: List[Dict], **kwargs) -> List[float]:
    """Reward 1.0 for each completion containing an ``<answer>...</answer>`` span, else 0.0.

    Args:
        completions: One entry per completion; ``completion[0]["content"]`` holds
            the generated text (conversational format).
    """
    return [
        1.0 if _ANSWER_PATTERN.search(completion[0]["content"]) else 0.0
        for completion in completions
    ]


def _as_number(value: Any) -> Optional[float]:
    """Return ``value`` as a float, or None if it is missing or not numeric."""
    if value is None or isinstance(value, bool):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _bbox_for_sample(gt_action: Dict[str, Any], bboxes: Any, index: int) -> Any:
    """Return the ground-truth bbox for sample ``index``.

    A ``bbox`` field inside the ground-truth action wins. Otherwise the
    ``ground_truth_bbox`` reward kwarg is used. Dataset columns reach reward
    functions as per-sample lists (as ``solution`` does), so a list of bboxes is
    indexed by sample; a single flat bbox is also accepted.
    """
    if "bbox" in gt_action:
        return gt_action["bbox"]
    if not bboxes:
        return None
    if isinstance(bboxes[0], (list, tuple)):
        return bboxes[index] if index < len(bboxes) else None
    return bboxes


def _point_reward(pred_action: Dict[str, Any], gt_action: Dict[str, Any], bbox: Any) -> float:
    """Score a click/long_press: 1.0 if the predicted point hits the target, else 0.0.

    If the ground truth has x/y, the point must be within CLICK_TOLERANCE on both
    axes. Otherwise it must lie inside the (x1, y1, x2, y2) bbox. Predictions
    without numeric coordinates score 0.0.
    """
    x = _as_number(pred_action.get("x"))
    y = _as_number(pred_action.get("y"))
    if x is None or y is None:
        return 0.0

    if "x" in gt_action and "y" in gt_action:
        gt_x = _as_number(gt_action["x"])
        gt_y = _as_number(gt_action["y"])
        if gt_x is None or gt_y is None:
            return 0.0
        hit = abs(x - gt_x) <= CLICK_TOLERANCE and abs(y - gt_y) <= CLICK_TOLERANCE
        return 1.0 if hit else 0.0

    if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
        return 0.0
    coords = [_as_number(v) for v in bbox[:4]]
    # [0, 0, 0, 0] is the "no bbox" placeholder written by load_gui_dataset();
    # treating it as a real box would reward clicks at the origin.
    if any(c is None for c in coords) or all(c == 0 for c in coords):
        return 0.0
    x1, y1, x2, y2 = coords
    return 1.0 if x1 <= x <= x2 and y1 <= y <= y2 else 0.0


def _score_action(
    pred_action: Optional[Dict[str, Any]],
    gt_action: Optional[Dict[str, Any]],
    bboxes: Any,
    index: int,
) -> float:
    """Score one predicted action against its ground truth (1.0 or 0.0)."""
    if not pred_action or not gt_action:
        return 0.0
    action_type = pred_action.get("action_type", "")
    # A non-string type (e.g. a JSON list) cannot match any known action.
    if not isinstance(action_type, str) or action_type != gt_action.get("action_type", ""):
        return 0.0

    if action_type in ("click", "long_press"):
        return _point_reward(pred_action, gt_action, _bbox_for_sample(gt_action, bboxes, index))
    param = _STRING_PARAM_FOR_ACTION.get(action_type)
    if param is not None:
        return 1.0 if pred_action.get(param, "") == gt_action.get(param, "") else 0.0
    if action_type in _PARAMETERLESS_ACTIONS:
        return 1.0
    return 0.0


def gui_action_reward(completions: List[Dict], solution: List[str], **kwargs) -> List[float]:
    """
    Reward for correct GUI action prediction.

    Action space:
    - click / long_press: correct if the point is within CLICK_TOLERANCE of the
      ground-truth x/y, or inside the ground-truth bbox
    - scroll: correct if direction matches
    - open_app / input_text: correct if string matches exactly
    - navigate_back / navigate_home / wait: correct if action type matches

    Args:
        completions: One entry per completion; ``completion[0]["content"]`` is the text.
        solution: Ground-truth answer strings, aligned with ``completions``.
        **kwargs: Optional ``ground_truth_bbox`` (per-sample list of bboxes, or one bbox).

    Returns:
        One reward (1.0 or 0.0) per completion.
    """
    bboxes = kwargs.get("ground_truth_bbox")
    return [
        _score_action(
            parse_gui_action(completion[0]["content"]),
            parse_gui_action(sol_text),
            bboxes,
            index,
        )
        for index, (completion, sol_text) in enumerate(zip(completions, solution))
    ]


def gui_combined_reward(completions: List[Dict], solution: List[str], **kwargs) -> List[float]:
    """Combined reward = format_reward + action_reward (range 0.0 to 2.0 per completion)."""
    format_rewards = gui_format_reward(completions, **kwargs)
    action_rewards = gui_action_reward(completions, solution, **kwargs)
    return [f + a for f, a in zip(format_rewards, action_rewards)]


# Names accepted by --reward_funcs.
reward_funcs_registry = {
    "format": gui_format_reward,
    "accuracy": gui_action_reward,
    "combined": gui_combined_reward,
}


@dataclass
class GUIGRPOScriptArguments(ScriptArguments):
    """Extended script arguments for GUI-Shift training."""

    # ':'-separated JSONL files, paired positionally with image_folders.
    data_file_paths: Optional[str] = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    image_folders: Optional[str] = field(
        default=None,
        metadata={"help": "Paths to image folders, separated by ':'"},
    )
    # 0.0 disables the train/validation split in main().
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split"},
    )
    # Keys of reward_funcs_registry.
    reward_funcs: List[str] = field(
        default_factory=lambda: ["format", "accuracy"],
        metadata={"help": "List of reward functions to use"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    # Not read anywhere in this script.
    task_type: Optional[str] = field(
        default="gui",
        metadata={"help": "Task type for GUI action prediction"},
    )


@dataclass
class GUIGRPOModelConfig(ModelConfig):
    """TRL model config plus a switch for freezing the vision tower."""

    freeze_vision_modules: bool = field(
        default=True,
        metadata={"help": "Freeze vision encoder and projector during training"},
    )


# NOTE: not referenced in this script; the prompt built in make_conversation()
# contains only the images and the dataset's problem text.
SYSTEM_PROMPT = (
    "You are a GUI automation assistant. Given two screenshots showing a GUI before and after an action, "
    "predict the action that caused the transition. "
    "Output your answer in the following format: <answer>{\"action_type\": ..., ...}</answer>"
)


def _convert_record(item: Dict[str, Any], image_folder: str) -> Dict[str, Any]:
    """Normalize one raw JSONL record in place and return it.

    Resolves image names to paths, extracts ``problem``/``solution`` from the
    first two ``conversations`` turns, and fills default ``ground_truth_bbox``/``k``.
    """
    image = item.pop("image", None)
    if isinstance(image, str):
        item["image_path"] = [os.path.join(image_folder, image)]
    elif isinstance(image, list):
        item["image_path"] = [os.path.join(image_folder, img) for img in image]
    # Dataset.from_list takes its columns from the first record, so every record
    # needs the key even when it has no images.
    item.setdefault("image_path", None)

    conversations = item.pop("conversations")
    item["problem"] = _IMAGE_PLACEHOLDER_PATTERN.sub("", conversations[0]["value"])
    solution_value = conversations[1]["value"]
    if isinstance(solution_value, str):
        item["solution"] = solution_value.replace("<answer>", "").replace("</answer>", "").strip()
    else:
        # json.dumps (not str) keeps dict/list answers as valid JSON, which
        # parse_gui_action expects; str() would produce a single-quoted repr.
        item["solution"] = json.dumps(solution_value)

    # [0, 0, 0, 0] marks "no bbox available".
    item["ground_truth_bbox"] = item.get("ground_truth_bbox", [0, 0, 0, 0])
    item["k"] = item.get("k", 1)
    return item


def load_gui_dataset(data_file_paths: str, image_folders: str):
    """Load GUI transition dataset from JSONL files.

    Args:
        data_file_paths: ':'-separated JSONL paths.
        image_folders: ':'-separated folders; image names in the i-th data file
            are resolved against the i-th folder.

    Returns:
        A ``datasets.Dataset`` with ``image_path``, ``problem``, ``solution``,
        ``ground_truth_bbox`` and ``k`` columns plus any other record fields.

    Raises:
        ValueError: if an argument is missing or the two lists differ in length.
    """
    from datasets import Dataset

    if not data_file_paths or not image_folders:
        raise ValueError("Both data_file_paths and image_folders must be provided")
    data_files = data_file_paths.split(":")
    image_folders_list = image_folders.split(":")
    if len(data_files) != len(image_folders_list):
        raise ValueError("Number of data files must match number of image folders")

    all_data = []
    for data_file, image_folder in zip(data_files, image_folders_list):
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                all_data.append(_convert_record(json.loads(line), image_folder))

    return Dataset.from_list(all_data)


def main(script_args, training_args, model_args):
    """Build dataset, rewards and the VLM GRPO trainer; train, save, optionally push.

    Args:
        script_args: GUIGRPOScriptArguments.
        training_args: trl GRPOConfig.
        model_args: GUIGRPOModelConfig.
    """
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print(f"Using VLM module: {vlm_module_cls.__name__}")

    unknown = [name for name in script_args.reward_funcs if name not in reward_funcs_registry]
    if unknown:
        raise ValueError(
            f"Unknown reward function(s) {unknown}; choose from {sorted(reward_funcs_registry)}"
        )
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print(f"Reward functions: {script_args.reward_funcs}")

    dataset = load_gui_dataset(script_args.data_file_paths, script_args.image_folders)
    print(f"Loaded dataset with {len(dataset)} samples")

    def make_conversation(example):
        """Convert a dataset example into the prompt/solution format used for GRPO."""
        if "image_path" in example and example["image_path"]:
            # One image placeholder per screenshot (e.g. before + after).
            images_content = [{"type": "image", "text": None} for _ in example["image_path"]]
        else:
            images_content = []
        return {
            "image_path": example.get("image_path", []),
            "problem": example["problem"],
            # Wrapped so parse_gui_action() can read it like a model completion.
            "solution": f"<answer> {example['solution']} </answer>",
            "ground_truth_bbox": example.get("ground_truth_bbox", [0, 0, 0, 0]),
            "k": example.get("k", 1),
            "prompt": [{
                "role": "user",
                "content": [
                    *images_content,
                    {"type": "text", "text": example["problem"]},
                ],
            }],
        }

    dataset = dataset.map(make_conversation, num_proc=8)

    splits = {"train": dataset}
    if script_args.val_split_ratio > 0:
        train_val_split = dataset.train_test_split(test_size=script_args.val_split_ratio)
        splits["train"] = train_val_split["train"]
        splits["validation"] = train_val_split["test"]

    trainer = VLMGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=splits["train"],
        eval_dataset=splits.get("validation") if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    # Resume when output_dir already holds checkpoint-* directories.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()


if __name__ == "__main__":
    parser = TrlParser((GUIGRPOScriptArguments, GRPOConfig, GUIGRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    if training_args.deepspeed and "zero3" in training_args.deepspeed:
        print("Zero3 detected, applying Qwen2.5-VL forward monkey patch")
        monkey_patch_qwen2_5vl_forward()
    main(script_args, training_args, model_args)