GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning

This repository implements GUI-Shift from the paper "GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning" by Longxi Gao, Li Zhang, and Mengwei Xu.

Overview

GUI-Shift is a self-supervised reinforcement learning framework that enhances Vision-Language Models (VLMs) for GUI agents without relying on costly textual annotations. The core idea is the K-step GUI Transition task: given two GUI screenshots (current state S_t and future state S_{t+k}), the model predicts the initial action that caused the transition.

Key Features

  • K-step GUI Transition: Self-supervised inverse dynamics learning from GUI trajectory pairs
  • GRPO Training: Group Relative Policy Optimization for efficient RL training
  • Rule-based Reward Function: Format reward + Action reward tailored for GUI tasks
  • Data Filtering Pipeline: Automatic filtering for high-quality training samples
  • Multi-model Support: Qwen2.5-VL, InternVL3, MiMo-VL backbones
  • Benchmark Evaluation: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro

Architecture

gui-shift/
├── src/
│   ├── data_construction/       # K-step GUI Transition data builder
│   ├── training/                # GRPO training with GUI rewards
│   ├── filtering/               # Data filtering pipeline
│   └── evaluation/              # Benchmark evaluation scripts
├── scripts/
│   ├── build_data.sh            # Construct K-step transition data
│   ├── train.sh                 # Launch GRPO training
│   ├── filter_data.sh           # Run data filtering
│   └── evaluate.sh              # Evaluate on benchmarks
├── configs/
│   ├── grpo_config.yaml         # GRPO hyperparameters
│   └── data_config.yaml         # Data construction settings
└── data/
    └── (AndroidControl trajectories)

Quick Start

1. Setup Environment

conda create -n gui-shift python=3.10
conda activate gui-shift
pip install -r requirements.txt

2. Prepare Data

Download the AndroidControl dataset and construct the K-step GUI Transition data:

bash scripts/build_data.sh \
  --input_dir /path/to/androidcontrol \
  --output_dir ./data/gui_transition \
  --k_values 1 2 3 4 \
  --samples_per_k 2000

3. Train with GUI-Shift

bash scripts/train.sh \
  --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
  --data_dir ./data/gui_transition/filtered \
  --output_dir ./checkpoints/gui-shift-qwen \
  --k 1

4. Evaluate

bash scripts/evaluate.sh \
  --model_path ./checkpoints/gui-shift-qwen \
  --benchmark androidcontrol_low

K-step GUI Transition

The core self-supervised task:

  1. Extract state pairs (S_t, S_{t+k}) from GUI trajectories
  2. The model sees both screenshots and predicts the action a_t that transitions S_t → S_{t+1}
  3. No textual instructions are needed: the future state S_{t+k} serves as the visual goal
  4. For k > 1, the model must infer temporal dynamics across multiple steps
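
The pairing step above can be sketched in a few lines of Python. This is a minimal illustration, not the repository's data builder: the function name `extract_state_pairs`, the file names, and the convention that a trajectory with n actions carries n + 1 screenshots are all assumptions made for the example.

```python
from typing import Dict, List, Tuple


def extract_state_pairs(
    screenshots: List[str], actions: List[Dict], k: int
) -> List[Tuple[str, str, Dict]]:
    """Return (S_t, S_{t+k}, a_t) triples for a single trajectory.

    Assumption: screenshots[i] is the state observed before actions[i], so a
    trajectory with n actions carries n + 1 screenshots.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if len(screenshots) != len(actions) + 1:
        raise ValueError("expected one more screenshot than actions")
    # The label is always the first action a_t, even when k > 1.
    return [
        (screenshots[t], screenshots[t + k], actions[t])
        for t in range(len(screenshots) - k)
    ]


# Toy trajectory with hypothetical file names and actions.
shots = ["s0.png", "s1.png", "s2.png", "s3.png"]
acts = [{"action_type": "click"}, {"action_type": "scroll"}, {"action_type": "wait"}]
pairs_k2 = extract_state_pairs(shots, acts, k=2)
```

For this four-screenshot trajectory, k=2 yields two pairs, (s0, s2) and (s1, s3), each labeled with the first action of its window.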

Data Format

{
  "id": "episode_001_step_005_k1",
  "image": ["screenshot_t.png", "screenshot_t+k.png"],
  "conversations": [
    {"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
    {"from": "gpt", "value": "<answer>{\"action_type\": \"click\", \"x\": 320, \"y\": 480}</answer>"}
  ],
  "ground_truth_bbox": [300, 460, 340, 500],
  "k": 1
}
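
As a rough sketch, a record in this format could be assembled as follows. The helper `make_record` and its parameters are hypothetical, the id pattern is inferred from the single example above, and the bounding box is assumed to be ordered as [x_min, y_min, x_max, y_max].

```python
import json
from typing import Dict, List

PROMPT = "<image><image>What action transitions the first screen to the second screen?"


def make_record(
    episode: int, step: int, k: int,
    image_t: str, image_tk: str,
    action: Dict, bbox: List[int],
) -> Dict:
    """Build one training record (hypothetical helper, field names from the example)."""
    answer = "<answer>" + json.dumps(action) + "</answer>"
    return {
        # Assumed id pattern: zero-padded episode and step, then the k value.
        "id": f"episode_{episode:03d}_step_{step:03d}_k{k}",
        "image": [image_t, image_tk],
        "conversations": [
            {"from": "human", "value": PROMPT},
            {"from": "gpt", "value": answer},
        ],
        "ground_truth_bbox": bbox,  # assumed [x_min, y_min, x_max, y_max]
        "k": k,
    }


record = make_record(
    episode=1, step=5, k=1,
    image_t="screenshot_t.png", image_tk="screenshot_t+k.png",
    action={"action_type": "click", "x": 320, "y": 480},
    bbox=[300, 460, 340, 500],
)
serialized = json.dumps(record, indent=2)
```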

Reward Design

Format Reward (R_f)

  • Enforces <answer>...</answer> tags in output
  • R_f = 1 if format correct, 0 otherwise
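
A minimal sketch of such a format check is shown below. It only tests for one well-formed `<answer>...</answer>` pair; whether the actual reward also constrains text outside the tags or the content inside them is not stated here, so treat the regular expression and the name `format_reward` as assumptions.

```python
import re

# Non-greedy match so the first closing tag ends the answer span.
ANSWER_PATTERN = re.compile(r"<answer>.*?</answer>", re.DOTALL)


def format_reward(completion: str) -> float:
    """Return 1.0 if the completion contains an <answer>...</answer> span, else 0.0."""
    return 1.0 if ANSWER_PATTERN.search(completion) else 0.0
```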

Action Reward (R_a)

  • click / long_press: Reward = 1 if predicted point falls within ground-truth bounding box
  • scroll: Reward = 1 if predicted direction matches ground truth
  • open_app / input_text: Reward = 1 if predicted string matches exactly
  • navigate_back / navigate_home / wait: Reward = 1 if action type matches

Total reward: R = R_f + R_a
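
The rules above could be combined roughly as in the following sketch. Several details are assumptions rather than the repository's behavior: the action type must match before any argument is compared, the bounding box is [x_min, y_min, x_max, y_max], and the argument names `direction`, `app_name`, and `text` are hypothetical.

```python
import json
import re
from typing import Optional, Sequence

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
POINT_ACTIONS = {"click", "long_press"}
TYPE_ONLY_ACTIONS = {"navigate_back", "navigate_home", "wait"}
# Hypothetical argument names; the data format only shows "x" and "y" for click.
STRING_FIELDS = {"open_app": "app_name", "input_text": "text"}


def action_reward(pred: dict, truth: dict, bbox: Optional[Sequence[float]] = None) -> float:
    """R_a: 1.0 if the predicted action matches the ground truth under the rules above."""
    kind = truth.get("action_type")
    if pred.get("action_type") != kind:
        return 0.0
    if kind in POINT_ACTIONS:
        x, y = pred.get("x"), pred.get("y")
        if bbox is None or not all(isinstance(v, (int, float)) for v in (x, y)):
            return 0.0
        x_min, y_min, x_max, y_max = bbox
        return float(x_min <= x <= x_max and y_min <= y <= y_max)
    if kind == "scroll":
        return float("direction" in pred and pred["direction"] == truth.get("direction"))
    if kind in STRING_FIELDS:
        field = STRING_FIELDS[kind]
        return float(field in pred and pred[field] == truth.get(field))
    return float(kind in TYPE_ONLY_ACTIONS)


def total_reward(completion: str, truth: dict, bbox: Optional[Sequence[float]] = None) -> float:
    """R = R_f + R_a for one sampled completion."""
    match = ANSWER_RE.search(completion)
    if match is None:
        return 0.0  # R_f = 0 and there is no answer to score
    try:
        pred = json.loads(match.group(1))
    except json.JSONDecodeError:
        return 1.0  # tags are well formed, but the action cannot be parsed
    r_a = action_reward(pred, truth, bbox) if isinstance(pred, dict) else 0.0
    return 1.0 + r_a
```

Under these assumptions, a click inside the box scores 2, a correctly formatted click outside the box scores 1, and an answer without tags scores 0.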

Training Configuration

| Hyper-parameter | Value |
| --- | --- |
| Learning rate | 1e-6 (warmup to 0) |
| Temperature | 0.9 |
| Num generations (G) | 8 |
| Num epochs | 4 |
| Max prompt length | 1024 |
| Max completion length | 256 |
| Per-device batch size | 2 |
| Gradient accumulation | 8 |
| Epsilon (clipping) | 0.2 |
| Beta (KL coefficient) | 0.04 |
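
To make the roles of G and epsilon concrete, here is a small sketch of the group-relative advantage and the clipped surrogate term as they are usually written for GRPO. It is an illustration under assumptions, not the training code in `src/training/`: the population standard deviation, the stabilizer `eps`, and both function names are choices made for this example, and the KL penalty weighted by beta is left out.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Normalize the rewards of one group of G completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population std; some trainers use sample std
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_term(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
    """PPO-style clipped surrogate for one token, with ratio = pi_new / pi_old."""
    clipped_ratio = max(1.0 - epsilon, min(1.0 + epsilon, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)


# G = 8 sampled completions with total rewards in {0, 1, 2}.
mixed = group_relative_advantages([2.0, 2.0, 1.0, 1.0, 2.0, 0.0, 1.0, 2.0])
uniform = group_relative_advantages([2.0] * 8)
```

When every completion in a group receives the same reward, as in `uniform`, all advantages are zero and the group contributes no policy-gradient signal. This is the usual motivation for discarding samples that are answered all correctly or all incorrectly, as the Data Filtering section does.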

Benchmark Results (from paper)

GUI Task Automation

| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
| --- | --- | --- | --- |
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
| GUI-Shift-Qwen (k=1) | 90.6 ↑6.8 | 70.4 ↑11.2 | 54.8 ↑9.9 |
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
| GUI-Shift-Intern (k=4) | 88.0 | 56.6 | 23.3 |

GUI Grounding

| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
| --- | --- | --- |
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
| GUI-Shift-Qwen | 86.6 | 27.9 |

Data Filtering

The filtering pipeline selects informative and challenging samples:

  1. Generate N=8 candidate actions for each state pair using the base model
  2. Score each with the rule-based reward
  3. Compute a per-sample diversity score: the proportion of the 8 responses that are correct
  4. Discard samples whose 8 responses are all correct (too easy) or all incorrect (too hard)
  5. Keep samples with mixed correctness (informative for learning)
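
The filtering rule can be sketched as below. This is a simplified stand-in for the pipeline in `src/filtering/`: it assumes the N responses have already been generated and scored, that a response counts as correct when its action reward is 1, and that each sample stores its scores under a hypothetical `action_rewards` key.

```python
from typing import Dict, List


def correct_fraction(action_rewards: List[float]) -> float:
    """Fraction of sampled responses whose action reward is 1."""
    if not action_rewards:
        raise ValueError("need at least one response")
    return sum(1 for r in action_rewards if r >= 1.0) / len(action_rewards)


def filter_samples(samples: List[Dict]) -> List[Dict]:
    """Keep only samples with mixed correctness (neither all right nor all wrong)."""
    return [s for s in samples if 0.0 < correct_fraction(s["action_rewards"]) < 1.0]


# Three toy samples, each with N = 8 scored responses.
candidates = [
    {"id": "easy", "action_rewards": [1.0] * 8},
    {"id": "hard", "action_rewards": [0.0] * 8},
    {"id": "mixed", "action_rewards": [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]},
]
kept = filter_samples(candidates)
```

Only the sample with mixed correctness survives; the all-correct and all-incorrect samples are dropped.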

Citation

@article{gao2025guishift,
  title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
  author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
  journal={arXiv preprint arXiv:2505.12493},
  year={2025}
}

License

Apache-2.0

Generated by ML Intern

This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "luanns/gui-shift"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

For non-causal architectures, replace AutoModelForCausalLM with the appropriate AutoModel class.
