---
tags:
- ml-intern
---
# GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning
This repository implements **GUI-Shift** from the paper ["GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning"](https://arxiv.org/abs/2505.12493) by Longxi Gao, Li Zhang, and Mengwei Xu.
## Overview
GUI-Shift is a self-supervised reinforcement learning framework that enhances Vision-Language Models (VLMs) for GUI agents without relying on costly textual annotations. The core idea is the **K-step GUI Transition** task: given two GUI screenshots (current state S_t and future state S_{t+k}), the model predicts the initial action that caused the transition.
## Key Features
- **K-step GUI Transition**: Self-supervised inverse dynamics learning from GUI trajectory pairs
- **GRPO Training**: Group Relative Policy Optimization for efficient RL training
- **Rule-based Reward Function**: Format reward + Action reward tailored for GUI tasks
- **Data Filtering Pipeline**: Automatic filtering for high-quality training samples
- **Multi-model Support**: Qwen2.5-VL, InternVL3, MimoVL backbones
- **Benchmark Evaluation**: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro
## Architecture
```
gui-shift/
├── src/
│   ├── data_construction/   # K-step GUI Transition data builder
│   ├── training/            # GRPO training with GUI rewards
│   ├── filtering/           # Data filtering pipeline
│   └── evaluation/          # Benchmark evaluation scripts
├── scripts/
│   ├── build_data.sh        # Construct K-step transition data
│   ├── train.sh             # Launch GRPO training
│   ├── filter_data.sh       # Run data filtering
│   └── evaluate.sh          # Evaluate on benchmarks
├── configs/
│   ├── grpo_config.yaml     # GRPO hyperparameters
│   └── data_config.yaml     # Data construction settings
└── data/
    └── (AndroidControl trajectories)
```
## Quick Start
### 1. Setup Environment
```bash
conda create -n gui-shift python=3.10
conda activate gui-shift
pip install -r requirements.txt
```
### 2. Prepare Data
Download the AndroidControl dataset and construct the K-step GUI Transition data:
```bash
bash scripts/build_data.sh \
--input_dir /path/to/androidcontrol \
--output_dir ./data/gui_transition \
--k_values 1 2 3 4 \
--samples_per_k 2000
```
### 3. Train with GUI-Shift
```bash
bash scripts/train.sh \
--model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
--data_dir ./data/gui_transition/filtered \
--output_dir ./checkpoints/gui-shift-qwen \
--k 1
```
### 4. Evaluate
```bash
bash scripts/evaluate.sh \
--model_path ./checkpoints/gui-shift-qwen \
--benchmark androidcontrol_low
```
## K-step GUI Transition
The core self-supervised task:
1. Extract state pairs (S_t, S_{t+k}) from GUI trajectories
2. The model sees both screenshots and predicts the action a_t taken on S_t, i.e., the first step of the transition (S_t → S_{t+1}) toward S_{t+k}
3. No textual instructions are needed: the future state S_{t+k} serves as the visual goal
4. For k > 1, the model must infer temporal dynamics across multiple steps
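
The pair-extraction step can be sketched in Python as below. This is a minimal illustration, not the repository's `src/data_construction` code: it assumes a hypothetical trajectory representation in which `screenshots[i]` is the screen shown before `actions[i]` is executed, and the names `TransitionSample` and `build_k_step_pairs` are made up for this example.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class TransitionSample:
    """One K-step GUI Transition sample (hypothetical container)."""
    current_screenshot: str  # S_t
    future_screenshot: str   # S_{t+k}
    action: Dict[str, Any]   # a_t, the action executed on S_t
    k: int


def build_k_step_pairs(
    screenshots: List[str], actions: List[Dict[str, Any]], k: int
) -> List[TransitionSample]:
    """Pair every S_t with S_{t+k} and label the pair with a_t.

    Assumes screenshots[i] is the screen before actions[i] runs, so a
    complete trajectory has exactly one more screenshot than actions.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    if len(screenshots) != len(actions) + 1:
        raise ValueError("expected one more screenshot than actions")
    samples = []
    # The last valid start index is len(screenshots) - 1 - k, so S_{t+k} exists.
    for t in range(len(screenshots) - k):
        samples.append(
            TransitionSample(screenshots[t], screenshots[t + k], actions[t], k)
        )
    return samples
```

For a trajectory with four screenshots and three actions, `k=2` yields two samples, (S_0, S_2, a_0) and (S_1, S_3, a_1); larger k values produce fewer pairs per episode.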
### Data Format
```json
{
"id": "episode_001_step_005_k1",
"image": ["screenshot_t.png", "screenshot_t+k.png"],
"conversations": [
{"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
{"from": "gpt", "value": "<answer>{\"action_type\": \"click\", \"x\": 320, \"y\": 480}</answer>"}
],
"ground_truth_bbox": [300, 460, 340, 500],
"k": 1
}
```
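
A record in this format could be assembled roughly as follows. The helper `make_record` is hypothetical, and the `id` pattern is inferred from the single example above rather than from a documented spec. The `gpt` turn serializes the action with `json.dumps`, whose default separators reproduce the example's spacing.

```python
import json
from typing import Any, Dict, List, Optional

# Prompt text copied from the example record above.
PROMPT = "<image><image>What action transitions the first screen to the second screen?"


def make_record(
    episode: int,
    step: int,
    k: int,
    image_t: str,
    image_tk: str,
    action: Dict[str, Any],
    bbox: Optional[List[int]] = None,
) -> Dict[str, Any]:
    """Build one training record (hypothetical helper; id pattern is assumed)."""
    record: Dict[str, Any] = {
        "id": f"episode_{episode:03d}_step_{step:03d}_k{k}",
        "image": [image_t, image_tk],
        "conversations": [
            {"from": "human", "value": PROMPT},
            # The target answer wraps the JSON-encoded action in <answer> tags.
            {"from": "gpt", "value": f"<answer>{json.dumps(action)}</answer>"},
        ],
        "k": k,
    }
    # Only click-like actions need a ground-truth box for the reward.
    if bbox is not None:
        record["ground_truth_bbox"] = bbox
    return record
```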
## Reward Design
### Format Reward (R_f)
- Enforces `<answer>...</answer>` tags in output
- R_f = 1 if format correct, 0 otherwise
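
One way to implement this check is sketched below. Requiring the tagged content to also parse as a JSON object is an assumption (the data format above stores actions that way); the README itself only states that the `<answer>` tags are required. `parse_answer` and `format_reward` are illustrative names.

```python
import json
import re
from typing import Any, Dict, Optional

# Non-greedy match: only the first <answer>...</answer> block is captured.
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def parse_answer(completion: str) -> Optional[Dict[str, Any]]:
    """Return the action dict inside <answer> tags, or None if malformed."""
    match = ANSWER_RE.search(completion)
    if match is None:
        return None
    try:
        action = json.loads(match.group(1).strip())
    except json.JSONDecodeError:
        return None
    # Assumed: the answer must be a JSON object, not a list or scalar.
    return action if isinstance(action, dict) else None


def format_reward(completion: str) -> float:
    """R_f: 1.0 if the completion contains a well-formed answer, else 0.0."""
    return 1.0 if parse_answer(completion) is not None else 0.0
```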
### Action Reward (R_a)
- **click / long_press**: Reward = 1 if predicted point falls within ground-truth bounding box
- **scroll**: Reward = 1 if predicted direction matches ground truth
- **open_app / input_text**: Reward = 1 if predicted string matches exactly
- **navigate_back / navigate_home / wait**: Reward = 1 if action type matches
Total reward: R = R_f + R_a
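
The rules above can be combined into one reward function, sketched here under several assumptions: actions are dicts shaped like the data-format example, bounding boxes are `[x1, y1, x2, y2]` with inclusive edges, and the `direction`, `app_name`, and `text` field names for scroll, open_app, and input_text are hypothetical placeholders for whatever schema your data uses. `pred` is the parsed model answer, or `None` when the format check failed.

```python
from typing import Any, Dict, Optional, Sequence

# Hypothetical field names; adjust to the schema of your trajectories.
STRING_FIELDS = {"open_app": "app_name", "input_text": "text"}
TYPE_ONLY = {"navigate_back", "navigate_home", "wait"}


def action_reward(pred: Dict[str, Any], gt: Dict[str, Any],
                  bbox: Optional[Sequence[float]] = None) -> float:
    """R_a: 1.0 if the predicted action matches the ground truth, else 0.0."""
    action_type = gt.get("action_type")
    if pred.get("action_type") != action_type:
        return 0.0
    if action_type in ("click", "long_press"):
        if bbox is None:
            return 0.0
        x1, y1, x2, y2 = bbox
        try:
            x, y = float(pred["x"]), float(pred["y"])
        except (KeyError, TypeError, ValueError):
            return 0.0
        # Point-in-box test with inclusive edges.
        return 1.0 if x1 <= x <= x2 and y1 <= y <= y2 else 0.0
    if action_type == "scroll":
        return 1.0 if pred.get("direction") == gt.get("direction") else 0.0
    if action_type in STRING_FIELDS:
        field = STRING_FIELDS[action_type]
        # Exact (case-sensitive) string match.
        return 1.0 if pred.get(field) == gt.get(field) else 0.0
    if action_type in TYPE_ONLY:
        return 1.0
    return 0.0


def total_reward(pred: Optional[Dict[str, Any]], gt: Dict[str, Any],
                 bbox: Optional[Sequence[float]] = None) -> float:
    """R = R_f + R_a; a malformed output (pred is None) earns 0 for both."""
    if pred is None:
        return 0.0
    return 1.0 + action_reward(pred, gt, bbox)
```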
## Training Configuration
| Hyper-parameter | Value |
|----------------|-------|
| Learning rate | 1e-6 (warmup to 0) |
| Temperature | 0.9 |
| Num generations (G) | 8 |
| Num epochs | 4 |
| Max prompt length | 1024 |
| Max completion length | 256 |
| Per-device batch size | 2 |
| Gradient accumulation | 8 |
| Epsilon (clipping) | 0.2 |
| Beta (KL coefficient) | 0.04 |
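
To make the GRPO-specific rows (G, epsilon, beta) concrete, here is a scalar sketch of the standard GRPO formulation from DeepSeekMath: rewards for the G completions of one prompt are normalized within the group, and each token's clipped surrogate is penalized by a KL estimate against the reference model. It is an illustration rather than this repository's trainer; the function names and the `eps` stabilizer are assumptions, and real implementations operate on batched tensors.

```python
import math
from statistics import fmean, pstdev
from typing import List, Sequence


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-4) -> List[float]:
    """A_i = (r_i - mean(r)) / (std(r) + eps) over the G completions of one prompt."""
    mean = fmean(rewards)
    std = pstdev(rewards)  # population std; some implementations use the sample std
    return [(r - mean) / (std + eps) for r in rewards]


def grpo_token_objective(logp: float, old_logp: float, ref_logp: float,
                         advantage: float, epsilon: float = 0.2,
                         beta: float = 0.04) -> float:
    """Per-token objective to maximize (the training loss is its negative)."""
    ratio = math.exp(logp - old_logp)
    clipped_ratio = max(min(ratio, 1.0 + epsilon), 1.0 - epsilon)
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # Non-negative KL estimator: exp(ref - logp) - (ref - logp) - 1.
    kl = math.exp(ref_logp - logp) - (ref_logp - logp) - 1.0
    return surrogate - beta * kl
```

If all G completions receive the same reward, every advantage is zero and only the KL term remains, so the prompt contributes no policy-gradient signal; this is consistent with the data filtering step below.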
## Benchmark Results (from paper)
### GUI Task Automation
| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
|-------|-----------|------------|----------------|
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
| **GUI-Shift-Qwen (k=1)** | **90.6** (↑6.8) | **70.4** (↑11.2) | **54.8** (↑9.9) |
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
| **GUI-Shift-Intern (k=4)** | 88.0 | 56.6 | 23.3 |
### GUI Grounding
| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
|-------|-------------------|-------------------|
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
| **GUI-Shift-Qwen** | **86.6** | **27.9** |
## Data Filtering
The filtering pipeline selects informative and challenging samples:
1. Generate N=8 candidate actions for each state pair using the base model
2. Score each with the rule-based reward
3. Measure answer consistency: the proportion of the N responses that are correct
4. **Discard** samples where all 8 responses are correct or all 8 are incorrect (too easy or too hard)
5. **Keep** samples with mixed correctness (informative for learning)
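
The keep/discard rule can be expressed as a small predicate. What counts as a "correct" response is an assumption here: the sketch treats a candidate as correct when it earns the maximum total reward R = R_f + R_a = 2. The names `keep_sample` and `filter_samples` are illustrative.

```python
from typing import Dict, List, Sequence


def keep_sample(rewards: Sequence[float], full_reward: float = 2.0) -> bool:
    """Keep a sample only if its N candidates have mixed correctness."""
    correct = [r >= full_reward for r in rewards]
    # all correct -> too easy, none correct -> too hard; both are discarded
    return any(correct) and not all(correct)


def filter_samples(candidate_rewards: Dict[str, Sequence[float]],
                   full_reward: float = 2.0) -> List[str]:
    """Return the ids of samples that survive filtering, in input order."""
    return [sid for sid, rewards in candidate_rewards.items()
            if keep_sample(rewards, full_reward)]
```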
## Citation
```bibtex
@article{gao2025guishift,
title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
journal={arXiv preprint arXiv:2505.12493},
year={2025}
}
```
## License
Apache-2.0
<!-- ml-intern-provenance -->
## Generated by ML Intern
This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
## Usage
```python
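from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub repository for this project; loading assumes it hosts model weights and a tokenizer.
model_id = "luanns/gui-shift"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```
For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class. Vision-language backbones such as Qwen2.5-VL also need their processor to handle screenshot inputs (for example, `AutoProcessor` together with `Qwen2_5_VLForConditionalGeneration`).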