GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning
This repository implements GUI-Shift from the paper "GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning" by Longxi Gao, Li Zhang, and Mengwei Xu.
Overview
GUI-Shift is a self-supervised reinforcement learning framework that enhances Vision-Language Models (VLMs) for GUI agents without relying on costly textual annotations. The core idea is the K-step GUI Transition task: given two GUI screenshots (current state S_t and future state S_{t+k}), the model predicts the initial action that caused the transition.
Key Features
- K-step GUI Transition: Self-supervised inverse dynamics learning from GUI trajectory pairs
- GRPO Training: Group Relative Policy Optimization for efficient RL training
- Rule-based Reward Function: Format reward + Action reward tailored for GUI tasks
- Data Filtering Pipeline: Automatic filtering for high-quality training samples
- Multi-model Support: Qwen2.5-VL, InternVL3, MimoVL backbones
- Benchmark Evaluation: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro
Architecture
```text
gui-shift/
├── src/
│   ├── data_construction/   # K-step GUI Transition data builder
│   ├── training/            # GRPO training with GUI rewards
│   ├── filtering/           # Data filtering pipeline
│   └── evaluation/          # Benchmark evaluation scripts
├── scripts/
│   ├── build_data.sh        # Construct K-step transition data
│   ├── train.sh             # Launch GRPO training
│   ├── filter_data.sh       # Run data filtering
│   └── evaluate.sh          # Evaluate on benchmarks
├── configs/
│   ├── grpo_config.yaml     # GRPO hyperparameters
│   └── data_config.yaml     # Data construction settings
└── data/
    └── (AndroidControl trajectories)
```
Quick Start
1. Setup Environment
```bash
conda create -n gui-shift python=3.10
conda activate gui-shift
pip install -r requirements.txt
```
2. Prepare Data
Download the AndroidControl dataset, then construct the K-step GUI Transition data:
```bash
bash scripts/build_data.sh \
  --input_dir /path/to/androidcontrol \
  --output_dir ./data/gui_transition \
  --k_values 1 2 3 4 \
  --samples_per_k 2000
```
3. Train with GUI-Shift
```bash
bash scripts/train.sh \
  --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
  --data_dir ./data/gui_transition/filtered \
  --output_dir ./checkpoints/gui-shift-qwen \
  --k 1
```
4. Evaluate
```bash
bash scripts/evaluate.sh \
  --model_path ./checkpoints/gui-shift-qwen \
  --benchmark androidcontrol_low
```
K-step GUI Transition
The core self-supervised task:
- Extract state pairs (S_t, S_{t+k}) from GUI trajectories
- The model sees both screenshots and predicts the action a_t that takes S_t → S_{t+1}, i.e., the first action of the k-step span
- No textual instructions are needed: the future state S_{t+k} serves as the visual goal
- For k > 1, the model must infer temporal dynamics across multiple steps
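The pair-extraction step can be sketched in Python as below. This is a minimal illustration, not the code in src/data_construction/: the function `extract_k_step_pairs` and the `TransitionSample` container are hypothetical names, and the sketch assumes each trajectory is stored as n + 1 screenshots and n actions, with screenshot i captured before action i is executed.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class TransitionSample:
    """One K-step GUI Transition example: (S_t, S_{t+k}) -> a_t."""
    episode_id: str
    step: int                 # t
    k: int
    current_screenshot: str   # path to S_t
    future_screenshot: str    # path to S_{t+k}
    action: dict[str, Any]    # a_t, the first action of the k-step span (the label)


def extract_k_step_pairs(episode_id: str, screenshots: list[str],
                         actions: list[dict[str, Any]], k: int) -> list[TransitionSample]:
    """Hypothetical helper: build every (S_t, S_{t+k}, a_t) triple in one trajectory.

    Assumes screenshots[i] is the screen shown before actions[i] runs, so a
    trajectory with n actions has n + 1 screenshots.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    if len(screenshots) != len(actions) + 1:
        raise ValueError("expected exactly one more screenshot than actions")
    return [
        TransitionSample(
            episode_id=episode_id,
            step=t,
            k=k,
            current_screenshot=screenshots[t],
            future_screenshot=screenshots[t + k],
            action=actions[t],
        )
        for t in range(len(screenshots) - k)
    ]


# Example: 3 actions -> 4 screenshots; k = 2 yields pairs for t = 0 and t = 1.
shots = ["s0.png", "s1.png", "s2.png", "s3.png"]
acts = [
    {"action_type": "click", "x": 320, "y": 480},
    {"action_type": "scroll", "direction": "down"},
    {"action_type": "navigate_back"},
]
pairs = extract_k_step_pairs("episode_001", shots, acts, k=2)
```

Under this storage assumption, a trajectory with n actions yields n + 1 - k samples for a given k, so larger k values produce slightly fewer pairs per episode.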
Data Format
```json
{
  "id": "episode_001_step_005_k1",
  "image": ["screenshot_t.png", "screenshot_t+k.png"],
  "conversations": [
    {"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
    {"from": "gpt", "value": "<answer>{\"action_type\": \"click\", \"x\": 320, \"y\": 480}</answer>"}
  ],
  "ground_truth_bbox": [300, 460, 340, 500],
  "k": 1
}
```
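One way to serialize an extracted transition into the record format above is sketched below. The helper name `to_record` is hypothetical, the ID pattern is inferred from the single example ID, and attaching `ground_truth_bbox` only when a box is supplied (for example, for click and long_press targets) is an assumption.

```python
import json
from typing import Any, Optional

PROMPT = "<image><image>What action transitions the first screen to the second screen?"


def to_record(episode_idx: int, step: int, k: int, current_png: str, future_png: str,
              action: dict[str, Any], bbox: Optional[list[int]] = None) -> dict[str, Any]:
    """Hypothetical helper: serialize one transition into the record format above."""
    record: dict[str, Any] = {
        # ID pattern inferred from the example "episode_001_step_005_k1".
        "id": f"episode_{episode_idx:03d}_step_{step:03d}_k{k}",
        "image": [current_png, future_png],  # [S_t, S_{t+k}]
        "conversations": [
            {"from": "human", "value": PROMPT},
            # The target action is a JSON object wrapped in <answer> tags.
            {"from": "gpt", "value": f"<answer>{json.dumps(action)}</answer>"},
        ],
    }
    if bbox is not None:  # assumed: only point actions carry a target box
        record["ground_truth_bbox"] = list(bbox)
    record["k"] = k
    return record


rec = to_record(1, 5, 1, "screenshot_t.png", "screenshot_t+k.png",
                {"action_type": "click", "x": 320, "y": 480}, [300, 460, 340, 500])
print(json.dumps(rec, indent=2))
```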
Reward Design
Format Reward (R_f)
- Enforces `<answer>...</answer>` tags in the output
- R_f = 1 if the format is correct, 0 otherwise
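A minimal sketch of R_f follows. It assumes the payload inside the tags should also parse as a JSON object, as in the data-format example; that stricter JSON check and the function name `format_reward` are assumptions, not taken from the paper.

```python
import json
import re

# Non-greedy match of the first <answer>...</answer> span; DOTALL allows newlines inside.
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def format_reward(completion: str) -> float:
    """R_f sketch: 1.0 if the completion has <answer>...</answer> tags whose
    payload parses as a JSON object (assumed stricter check), else 0.0."""
    match = ANSWER_RE.search(completion)
    if match is None:
        return 0.0
    try:
        payload = json.loads(match.group(1))
    except json.JSONDecodeError:
        return 0.0
    return 1.0 if isinstance(payload, dict) else 0.0


print(format_reward('<answer>{"action_type": "navigate_back"}</answer>'))  # 1.0
print(format_reward("Tap the search bar."))                                 # 0.0
```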
Action Reward (R_a)
- click / long_press: Reward = 1 if predicted point falls within ground-truth bounding box
- scroll: Reward = 1 if predicted direction matches ground truth
- open_app / input_text: Reward = 1 if predicted string matches exactly
- navigate_back / navigate_home / wait: Reward = 1 if action type matches
Total reward: R = R_f + R_a
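The action rules and the total reward can be sketched as follows. Assumptions not confirmed by the source: actions are JSON objects with an `action_type` key, scroll uses a `direction` field, open_app uses `app_name`, input_text uses `text`, a point on the box edge counts as inside, and every rule also requires the predicted action type to match. Here R_f is judged by whether `parse_action` finds a JSON object inside the tags, mirroring the stricter format check above.

```python
import json
import re
from typing import Any, Optional

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
TYPE_ONLY = {"navigate_back", "navigate_home", "wait"}
STRING_FIELD = {"open_app": "app_name", "input_text": "text"}  # assumed field names


def parse_action(completion: str) -> Optional[dict[str, Any]]:
    """Return the JSON object inside <answer>...</answer>, or None if malformed."""
    match = ANSWER_RE.search(completion)
    if match is None:
        return None
    try:
        action = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    return action if isinstance(action, dict) else None


def action_reward(pred: Optional[dict[str, Any]], gt: dict[str, Any],
                  gt_bbox: Optional[list[float]] = None) -> float:
    """R_a sketch: 1.0 if the predicted action satisfies the rule for gt's type."""
    if pred is None or pred.get("action_type") != gt["action_type"]:
        return 0.0
    kind = gt["action_type"]
    if kind in ("click", "long_press"):
        if gt_bbox is None:
            return 0.0
        try:
            x, y = float(pred["x"]), float(pred["y"])
        except (KeyError, TypeError, ValueError):
            return 0.0
        x1, y1, x2, y2 = gt_bbox
        return 1.0 if x1 <= x <= x2 and y1 <= y <= y2 else 0.0
    if kind == "scroll":
        return 1.0 if pred.get("direction") == gt.get("direction") else 0.0
    if kind in STRING_FIELD:
        field = STRING_FIELD[kind]
        return 1.0 if pred.get(field) == gt.get(field) else 0.0  # exact string match
    if kind in TYPE_ONLY:
        return 1.0  # matching action type is sufficient
    return 0.0  # action types not covered by the rules above


def total_reward(completion: str, gt: dict[str, Any],
                 gt_bbox: Optional[list[float]] = None) -> float:
    """R = R_f + R_a, with R_f judged by whether parse_action succeeds."""
    pred = parse_action(completion)
    r_f = 1.0 if pred is not None else 0.0
    return r_f + action_reward(pred, gt, gt_bbox)
```

With this decomposition, a well-formed but wrong answer still earns 1, which keeps a gradient toward the output format even when the action is incorrect.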
Training Configuration
| Hyper-parameter | Value |
|---|---|
| Learning rate | 1e-6 (warmup to 0) |
| Temperature | 0.9 |
| Num generations (G) | 8 |
| Num epochs | 4 |
| Max prompt length | 1024 |
| Max completion length | 256 |
| Per-device batch size | 2 |
| Gradient accumulation | 8 |
| Epsilon (clipping) | 0.2 |
| Beta (KL coefficient) | 0.04 |
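To show where G, Epsilon and Beta enter the objective, here is a simplified GRPO loss for one group of G completions, following the published GRPO formulation: group-normalized advantages, a PPO-style clipped probability ratio, and a KL penalty toward a reference model using the k3 estimator. For brevity it works on sequence-level log-probabilities, whereas practical implementations typically average over tokens; the function names are hypothetical and this is not the repository's training code.

```python
import math
from statistics import mean, pstdev


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Group-relative advantage: standardize each reward within its group of G."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def grpo_group_loss(logp_new: list[float], logp_old: list[float], logp_ref: list[float],
                    rewards: list[float], clip_eps: float = 0.2, beta: float = 0.04) -> float:
    """Simplified GRPO loss for one prompt's G completions (sequence-level log-probs).

    clip_eps and beta default to the Epsilon and Beta values in the table above.
    """
    advantages = group_advantages(rewards)
    objective = 0.0
    for lp, lp_old, lp_ref, adv in zip(logp_new, logp_old, logp_ref, advantages):
        ratio = math.exp(lp - lp_old)                         # pi_theta / pi_old
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        surrogate = min(ratio * adv, clipped * adv)           # PPO-style clipped term
        log_diff = lp_ref - lp
        kl = math.exp(log_diff) - log_diff - 1.0              # k3 estimate of KL(pi_theta || pi_ref)
        objective += surrogate - beta * kl
    return -objective / len(rewards)                          # minimize the negative objective
```

With G = 8, a prompt whose eight completions all receive the same reward gets all-zero advantages, so it contributes no reward-driven gradient (only the KL term remains). This is consistent with the Data Filtering step, which discards such all-correct or all-incorrect samples.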
Benchmark Results (from paper)
GUI Task Automation
| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
|---|---|---|---|
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
| GUI-Shift-Qwen (k=1) | 90.6 (+6.8) | 70.4 (+11.2) | 54.8 (+9.9) |
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
| GUI-Shift-Intern (k=4) | 88.0 | 56.6 | 23.3 |
GUI Grounding
| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
|---|---|---|
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
| GUI-Shift-Qwen | 86.6 | 27.9 |
Data Filtering
The filtering pipeline selects informative and challenging samples:
- Generate N=8 candidate actions for each state pair using the base model
- Score each with the rule-based reward
- Compute a per-sample score: the proportion of the 8 responses that are correct versus incorrect
- Discard samples where all 8 responses are correct (too easy) or all 8 are incorrect (too hard)
- Keep samples with mixed correctness (informative for learning)
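The filtering rule reduces to keeping samples whose N responses are neither all correct nor all incorrect. The sketch below assumes the N = 8 candidate actions have already been generated with the base model and scored with an action reward in {0, 1}; the function name `filter_informative` is hypothetical.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def filter_informative(samples: Sequence[T], action_rewards: Sequence[Sequence[float]]) -> list[T]:
    """Keep samples whose N scored responses mix correct (1) and incorrect (0) answers."""
    kept = []
    for sample, rewards in zip(samples, action_rewards):
        accuracy = sum(1 for r in rewards if r >= 1.0) / len(rewards)  # share of correct responses
        if 0.0 < accuracy < 1.0:  # drop all-correct (too easy) and all-incorrect (too hard)
            kept.append(sample)
    return kept


samples = ["easy", "hard", "informative"]
scores = [[1.0] * 8, [0.0] * 8, [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]]
print(filter_informative(samples, scores))  # ['informative']
```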
Citation
```bibtex
@article{gao2025guishift,
  title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
  author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
  journal={arXiv preprint arXiv:2505.12493},
  year={2025}
}
```
License
Apache-2.0
Generated by ML Intern
This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "luanns/gui-shift"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```
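GUI-Shift checkpoints are vision-language models (the Quick Start fine-tunes Qwen/Qwen2.5-VL-7B-Instruct), so AutoModelForCausalLM and AutoTokenizer may fail to load them or may not handle screenshot inputs. In that case, load the model with the matching vision-language class (for example, AutoModelForImageTextToText in recent transformers releases) and use AutoProcessor in place of AutoTokenizer.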