---
tags:
- ml-intern
---

# GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning

This repository implements **GUI-Shift** from the paper ["GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning"](https://arxiv.org/abs/2505.12493) by Longxi Gao, Li Zhang, and Mengwei Xu.

## Overview

GUI-Shift is a self-supervised reinforcement learning framework that improves Vision-Language Models (VLMs) used as GUI agents without relying on costly textual annotations. The core idea is the **K-step GUI Transition** task: given two GUI screenshots (the current state S_t and a future state S_{t+k}), the model predicts the initial action that caused the transition.

## Key Features

- **K-step GUI Transition**: Self-supervised inverse dynamics learning from pairs of states in GUI trajectories
- **GRPO Training**: Group Relative Policy Optimization for efficient RL training
- **Rule-based Reward Function**: Format reward + action reward tailored to GUI tasks
- **Data Filtering Pipeline**: Automatic filtering for high-quality training samples
- **Multi-model Support**: Qwen2.5-VL, InternVL3, and MimoVL backbones
- **Benchmark Evaluation**: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro

## Architecture

```
gui-shift/
├── src/
│   ├── data_construction/   # K-step GUI Transition data builder
│   ├── training/            # GRPO training with GUI rewards
│   ├── filtering/           # Data filtering pipeline
│   └── evaluation/          # Benchmark evaluation scripts
├── scripts/
│   ├── build_data.sh        # Construct K-step transition data
│   ├── train.sh             # Launch GRPO training
│   ├── filter_data.sh       # Run data filtering
│   └── evaluate.sh          # Evaluate on benchmarks
├── configs/
│   ├── grpo_config.yaml     # GRPO hyperparameters
│   └── data_config.yaml     # Data construction settings
└── data/
    └── (AndroidControl trajectories)
```

## Quick Start

### 1. Setup Environment

```bash
conda create -n gui-shift python=3.10
conda activate gui-shift
pip install -r requirements.txt
```

### 2. Prepare Data

Download the AndroidControl dataset and construct K-step GUI Transition data:

```bash
bash scripts/build_data.sh \
    --input_dir /path/to/androidcontrol \
    --output_dir ./data/gui_transition \
    --k_values 1 2 3 4 \
    --samples_per_k 2000
```

### 3. Train with GUI-Shift

```bash
bash scripts/train.sh \
    --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
    --data_dir ./data/gui_transition/filtered \
    --output_dir ./checkpoints/gui-shift-qwen \
    --k 1
```

### 4. Evaluate

```bash
bash scripts/evaluate.sh \
    --model_path ./checkpoints/gui-shift-qwen \
    --benchmark androidcontrol_low
```

## K-step GUI Transition

The core self-supervised task:

1. Extract state pairs (S_t, S_{t+k}) from GUI trajectories.
2. The model sees both screenshots and predicts the first action a_t on the path from S_t to S_{t+k}, i.e. the action that takes S_t to S_{t+1}.
3. No textual instructions are needed: the future state S_{t+k} serves as the visual goal.
4. For k > 1, the model must infer temporal dynamics that span multiple steps.

### Data Format

```json
{
  "id": "episode_001_step_005_k1",
  "image": ["screenshot_t.png", "screenshot_t+k.png"],
  "conversations": [
    {"from": "human", "value": "What action transitions the first screen to the second screen?"},
    {"from": "gpt", "value": "{\"action_type\": \"click\", \"x\": 320, \"y\": 480}"}
  ],
  "ground_truth_bbox": [300, 460, 340, 500],
  "k": 1
}
```

`ground_truth_bbox` is the bounding box of the target UI element; the click / long_press reward checks whether the predicted point falls inside it.

## Reward Design

### Format Reward (R_f)

- Enforces `...` tags in the output
- R_f = 1 if the format is correct, 0 otherwise

### Action Reward (R_a)

- **click / long_press**: R_a = 1 if the predicted point falls within the ground-truth bounding box
- **scroll**: R_a = 1 if the predicted direction matches the ground truth
- **open_app / input_text**: R_a = 1 if the predicted string matches exactly
- **navigate_back / navigate_home / wait**: R_a = 1 if the action type matches

R_a = 0 otherwise.

Total reward: R = R_f + R_a, so each response scores 0, 1, or 2.

## Training Configuration

| Hyper-parameter | Value |
|-----------------|-------|
| Learning rate | 1e-6 (warmup to 0) |
| Temperature | 0.9 |
| Num generations (G) | 8 |
| Num epochs | 4 |
| Max prompt length | 1024 |
| Max completion length | 256 |
| Per-device batch size | 2 |
| Gradient accumulation | 8 |
| Epsilon (clipping) | 0.2 |
| Beta (KL coefficient) | 0.04 |

## Benchmark Results (from paper)

### GUI Task Automation

AC-Low / AC-High: AndroidControl with low-level / high-level instructions; EM: exact-match accuracy (%).

| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
|-------|-----------|------------|----------------|
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
| **GUI-Shift-Qwen (k=1)** | **90.6** ↑6.8 | **70.4** ↑11.2 | **54.8** ↑9.9 |
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
| **GUI-Shift-Intern (k=4)** | 88.0 ↓2.0 | 56.6 ↑6.8 | 23.3 ↑3.0 |

### GUI Grounding

| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
|-------|-------------------|--------------------|
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
| **GUI-Shift-Qwen** | **86.6** | **27.9** |

## Data Filtering

The filtering pipeline selects informative and challenging samples:

1. Generate N = 8 candidate actions for each state pair using the base model.
2. Score each candidate with the rule-based reward.
3. Compute the fraction of the N candidates that are correct.
4. **Discard** samples where all 8 responses are correct (too easy) or all 8 are incorrect (too hard).
5. **Keep** samples with mixed correctness, which are informative for learning.

## Citation

```bibtex
@article{gao2025guishift,
  title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
  author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
  journal={arXiv preprint arXiv:2505.12493},
  year={2025}
}
```

## License

Apache-2.0

## Generated by ML Intern

This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
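- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "luanns/gui-shift"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```

For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class. GUI-Shift checkpoints are vision-language models, so a Qwen2.5-VL-based checkpoint, for example, is loaded with `Qwen2_5_VLForConditionalGeneration` together with `AutoProcessor` (which also handles the screenshots) rather than with `AutoTokenizer`.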
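The pair construction described in the K-step GUI Transition section, combined with the Data Format example, can be sketched as below. This is a minimal illustration under stated assumptions, not the code in `src/data_construction/`: `build_k_step_samples` is a hypothetical name, and it assumes an episode is stored as n + 1 screenshot paths plus n action dicts (and n target boxes), where `actions[i]` moves `screenshots[i]` to `screenshots[i + 1]`.

```python
import json
from typing import Dict, List, Optional


def build_k_step_samples(
    episode_id: str,
    screenshots: List[str],
    actions: List[dict],
    bboxes: List[Optional[list]],
    k: int,
) -> List[Dict]:
    """Pair each S_t with S_{t+k} and label the pair with the first action a_t.

    Hypothetical layout: actions[i] (with target box bboxes[i], or None for
    actions without a target element) moves screenshots[i] to screenshots[i + 1].
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if len(screenshots) != len(actions) + 1 or len(bboxes) != len(actions):
        raise ValueError("expected n + 1 screenshots, n actions and n boxes")
    samples = []
    # S_{t+k} must exist, so t runs from 0 to n - k inclusive.
    for t in range(len(actions) - k + 1):
        samples.append({
            "id": f"{episode_id}_step_{t:03d}_k{k}",
            "image": [screenshots[t], screenshots[t + k]],
            "conversations": [
                {"from": "human",
                 "value": "What action transitions the first screen to the second screen?"},
                # Only the first action a_t is the target, even when k > 1.
                {"from": "gpt", "value": json.dumps(actions[t])},
            ],
            "ground_truth_bbox": bboxes[t],
            "k": k,
        })
    return samples
```

For a 3-action episode, k = 1 yields three pairs and k = 2 yields two, the second pairing the screenshots at steps 1 and 3; a k larger than the episode length yields none.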
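The rule-based reward from the Reward Design section can be sketched in plain Python as follows. This is an illustration, not the repository's implementation: the function names, the argument keys for scroll / open_app / input_text (`direction`, `app_name`, `text`), and the assumption that the answer is a JSON action object wrapped in the required tags are all hypothetical. The card does not name the tags, so the caller passes them in.

```python
import json
import re
from typing import Optional

# Hypothetical argument keys compared for string-valued actions; the card only
# shows the click schema ({"action_type", "x", "y"}).
STRING_FIELDS = {"open_app": "app_name", "input_text": "text"}
# Actions rewarded on action type alone.
TYPE_ONLY = {"navigate_back", "navigate_home", "wait"}


def extract_answer(completion: str, open_tag: str, close_tag: str) -> Optional[str]:
    """Return the text between the first open_tag ... close_tag pair, or None."""
    pattern = re.escape(open_tag) + r"(.*?)" + re.escape(close_tag)
    match = re.search(pattern, completion, re.DOTALL)
    return match.group(1).strip() if match else None


def parse_action(completion: str, open_tag: str, close_tag: str) -> Optional[dict]:
    """Parse the tagged answer as a JSON action object; None if malformed."""
    answer = extract_answer(completion, open_tag, close_tag)
    if answer is None:
        return None
    try:
        action = json.loads(answer)
    except json.JSONDecodeError:
        return None
    return action if isinstance(action, dict) else None


def format_reward(completion: str, open_tag: str, close_tag: str) -> float:
    """R_f: 1.0 if the tags are present and wrap a JSON object, else 0.0."""
    return 1.0 if parse_action(completion, open_tag, close_tag) is not None else 0.0


def action_reward(pred: dict, gt: dict, gt_bbox: Optional[list] = None) -> float:
    """R_a: 1.0 if the predicted action matches the ground truth, else 0.0."""
    action_type = gt.get("action_type")
    if pred.get("action_type") != action_type:
        return 0.0  # assumption: a wrong action type never earns R_a
    if action_type in ("click", "long_press"):
        if gt_bbox is None:
            return 0.0
        x_min, y_min, x_max, y_max = gt_bbox
        try:
            x, y = float(pred["x"]), float(pred["y"])
        except (KeyError, TypeError, ValueError):
            return 0.0
        return 1.0 if x_min <= x <= x_max and y_min <= y <= y_max else 0.0
    if action_type == "scroll":
        return 1.0 if pred.get("direction") == gt.get("direction") else 0.0
    if action_type in STRING_FIELDS:
        key = STRING_FIELDS[action_type]
        return 1.0 if pred.get(key) == gt.get(key) else 0.0
    return 1.0 if action_type in TYPE_ONLY else 0.0


def total_reward(completion: str, gt: dict, gt_bbox: Optional[list],
                 open_tag: str, close_tag: str) -> float:
    """R = R_f + R_a; an unparseable answer scores 0 for both terms."""
    pred = parse_action(completion, open_tag, close_tag)
    if pred is None:
        return 0.0
    r_f = 1.0  # the format is valid once the tagged action parses
    return r_f + action_reward(pred, gt, gt_bbox)
```

Tying R_a to a successful parse is a design choice of this sketch: an action that cannot be read cannot be checked, so a malformed response always scores 0 and R stays in {0, 1, 2}.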
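Steps 3 to 5 of the Data Filtering pipeline reduce to a check on each sample's correctness fraction. The sketch below assumes the N = 8 candidates have already been generated and scored with the rule-based reward, and that a candidate counts as correct when its action reward is 1; `correctness_fraction` and `filter_samples` are hypothetical names, not functions from `src/filtering/`.

```python
from typing import Dict, List, Tuple


def correctness_fraction(action_rewards: List[float], threshold: float = 1.0) -> float:
    """Fraction of candidate responses whose action reward reaches `threshold`."""
    if not action_rewards:
        raise ValueError("need at least one scored candidate")
    return sum(r >= threshold for r in action_rewards) / len(action_rewards)


def filter_samples(scored: Dict[str, List[float]]) -> Tuple[List[str], List[str]]:
    """Split sample ids into (kept, discarded).

    A sample is discarded when every candidate is correct (too easy) or every
    candidate is wrong (too hard); samples with mixed outcomes are kept.
    """
    kept, discarded = [], []
    for sample_id, rewards in scored.items():
        frac = correctness_fraction(rewards)
        if 0.0 < frac < 1.0:
            kept.append(sample_id)
        else:
            discarded.append(sample_id)
    return kept, discarded
```

Dropping uniform samples also fits GRPO: when every response in a group receives the same reward, the group-normalized advantages are all zero, so such a sample contributes no reward-driven policy-gradient signal.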