---
tags:
- ml-intern
---
# GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning
This repository implements **GUI-Shift** from the paper ["GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning"](https://arxiv.org/abs/2505.12493) by Longxi Gao, Li Zhang, and Mengwei Xu.
## Overview
GUI-Shift is a self-supervised reinforcement learning framework that enhances Vision-Language Models (VLMs) for GUI agents without relying on costly textual annotations. The core idea is the **K-step GUI Transition** task: given two GUI screenshots (current state S_t and future state S_{t+k}), the model predicts the initial action that caused the transition.
## Key Features
- **K-step GUI Transition**: Self-supervised inverse dynamics learning from GUI trajectory pairs
- **GRPO Training**: Group Relative Policy Optimization for efficient RL training
- **Rule-based Reward Function**: Format reward + Action reward tailored for GUI tasks
- **Data Filtering Pipeline**: Automatic filtering for high-quality training samples
- **Multi-model Support**: Qwen2.5-VL, InternVL3, and MiMo-VL backbones
- **Benchmark Evaluation**: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro
## Architecture
```
gui-shift/
├── src/
│   ├── data_construction/   # K-step GUI Transition data builder
│   ├── training/            # GRPO training with GUI rewards
│   ├── filtering/           # Data filtering pipeline
│   └── evaluation/          # Benchmark evaluation scripts
├── scripts/
│   ├── build_data.sh        # Construct K-step transition data
│   ├── train.sh             # Launch GRPO training
│   ├── filter_data.sh       # Run data filtering
│   └── evaluate.sh          # Evaluate on benchmarks
├── configs/
│   ├── grpo_config.yaml     # GRPO hyperparameters
│   └── data_config.yaml     # Data construction settings
└── data/
    └── (AndroidControl trajectories)
```
## Quick Start
### 1. Setup Environment
```bash
conda create -n gui-shift python=3.10
conda activate gui-shift
pip install -r requirements.txt
```
### 2. Prepare Data
Download the AndroidControl dataset, then construct the K-step GUI Transition data:
```bash
bash scripts/build_data.sh \
--input_dir /path/to/androidcontrol \
--output_dir ./data/gui_transition \
--k_values 1 2 3 4 \
--samples_per_k 2000
```
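### 3. Train with GUI-Shift
Training reads the filtered subset of the constructed data, so run `scripts/filter_data.sh` first (see Data Filtering) so that the `--data_dir` path below exists.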
```bash
bash scripts/train.sh \
--model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
--data_dir ./data/gui_transition/filtered \
--output_dir ./checkpoints/gui-shift-qwen \
--k 1
```
### 4. Evaluate
```bash
bash scripts/evaluate.sh \
--model_path ./checkpoints/gui-shift-qwen \
--benchmark androidcontrol_low
```
## K-step GUI Transition
The core self-supervised task:
1. Extract state pairs (S_t, S_{t+k}) from GUI trajectories
2. The model sees both screenshots and predicts the action a_t that transitions S_t → S_{t+1}
3. No textual instructions needed — the future state S_{t+k} serves as the visual goal
4. For k > 1, the model must infer temporal dynamics across multiple steps
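The pairing in steps 1 and 2 can be sketched as follows. This is a minimal illustration, not the repository's data builder: the function name, the trajectory layout (one more screenshot than actions, with `actions[t]` leading from `screenshots[t]` to `screenshots[t + 1]`), and the output keys are all assumptions.

```python
from typing import Dict, List


def build_k_step_pairs(screenshots: List[str], actions: List[Dict], k: int) -> List[Dict]:
    """Pair each state S_t with S_{t+k}; the label is the first action a_t."""
    if k < 1:
        raise ValueError("k must be at least 1")
    # Assumed layout: actions[t] leads from screenshots[t] to screenshots[t + 1].
    if len(screenshots) != len(actions) + 1:
        raise ValueError("expected exactly one more screenshot than actions")
    pairs = []
    # The last valid start index is the one whose future state still exists.
    for t in range(len(screenshots) - k):
        pairs.append({
            "current": screenshots[t],
            "future": screenshots[t + k],
            "action": actions[t],  # always the *initial* action, even for k > 1
            "k": k,
        })
    return pairs
```

A trajectory with 4 screenshots yields 3 pairs for k = 1 and 2 pairs for k = 2, so larger k values produce fewer samples per episode.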
### Data Format
```json
{
"id": "episode_001_step_005_k1",
"image": ["screenshot_t.png", "screenshot_t+k.png"],
"conversations": [
{"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
{"from": "gpt", "value": "<answer>{\"action_type\": \"click\", \"x\": 320, \"y\": 480}</answer>"}
],
"ground_truth_bbox": [300, 460, 340, 500],
"k": 1
}
```
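A record in this format can be sanity-checked before training. The sketch below is illustrative only: the helper names are hypothetical, and it assumes `ground_truth_bbox` is ordered `[x_min, y_min, x_max, y_max]`, which the example values suggest but the format does not state.

```python
import json
import re

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def parse_target_action(sample: dict) -> dict:
    """Return the ground-truth action stored in the assistant ("gpt") turn."""
    gpt_turn = next(t for t in sample["conversations"] if t["from"] == "gpt")
    match = ANSWER_RE.search(gpt_turn["value"])
    if match is None:
        raise ValueError("assistant turn has no <answer>...</answer> block")
    return json.loads(match.group(1))


def check_sample(sample: dict) -> None:
    """Raise ValueError if the record is internally inconsistent."""
    prompt = next(t for t in sample["conversations"] if t["from"] == "human")
    if prompt["value"].count("<image>") != len(sample["image"]):
        raise ValueError("<image> placeholders must match the number of images")
    action = parse_target_action(sample)
    if action["action_type"] in {"click", "long_press"}:
        # Assumed bbox order: [x_min, y_min, x_max, y_max].
        x_min, y_min, x_max, y_max = sample["ground_truth_bbox"]
        if not (x_min <= action["x"] <= x_max and y_min <= action["y"] <= y_max):
            raise ValueError("target point lies outside ground_truth_bbox")


sample = {
    "id": "episode_001_step_005_k1",
    "image": ["screenshot_t.png", "screenshot_t+k.png"],
    "conversations": [
        {"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
        {"from": "gpt", "value": '<answer>{"action_type": "click", "x": 320, "y": 480}</answer>'},
    ],
    "ground_truth_bbox": [300, 460, 340, 500],
    "k": 1,
}
check_sample(sample)
print(parse_target_action(sample)["action_type"])  # click
```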
## Reward Design
### Format Reward (R_f)
- Enforces `<answer>...</answer>` tags in output
- R_f = 1 if format correct, 0 otherwise
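One way to implement this check is a regular expression over the completion. The sketch below assumes that a single well-formed `<answer>...</answer>` block anywhere in the output counts as correct; the actual rule in the training code may be stricter (for example, rejecting text outside the tags).

```python
import re

# Non-greedy match so that the first closing tag ends the block;
# DOTALL lets the answer span multiple lines.
ANSWER_BLOCK = re.compile(r"<answer>.*?</answer>", re.DOTALL)


def format_reward(completion: str) -> float:
    """R_f: 1.0 if the completion contains an <answer>...</answer> block, else 0.0."""
    return 1.0 if ANSWER_BLOCK.search(completion) else 0.0
```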
### Action Reward (R_a)
- **click / long_press**: Reward = 1 if predicted point falls within ground-truth bounding box
- **scroll**: Reward = 1 if predicted direction matches ground truth
- **open_app / input_text**: Reward = 1 if predicted string matches exactly
- **navigate_back / navigate_home / wait**: Reward = 1 if action type matches
Total reward: R = R_f + R_a
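The action rules and the sum can be sketched as below, operating on already-parsed action dictionaries. Several details are assumptions made for illustration: the field names `direction`, `app_name`, and `text`, the bounding-box order `[x_min, y_min, x_max, y_max]`, and the choice to give R_a = 0 when no action could be parsed.

```python
from typing import Optional, Sequence

POINT_ACTIONS = {"click", "long_press"}
# Hypothetical field names holding the string argument of each text action.
TEXT_FIELDS = {"open_app": "app_name", "input_text": "text"}
TYPE_ONLY_ACTIONS = {"navigate_back", "navigate_home", "wait"}


def action_reward(pred: dict, truth: dict, bbox: Optional[Sequence[float]] = None) -> float:
    """R_a: 1.0 if the predicted action matches the ground truth under the rules above."""
    action_type = truth.get("action_type")
    if pred.get("action_type") != action_type:
        return 0.0
    if action_type in POINT_ACTIONS:
        x, y = pred.get("x"), pred.get("y")
        if bbox is None or not all(isinstance(v, (int, float)) for v in (x, y)):
            return 0.0
        x_min, y_min, x_max, y_max = bbox  # assumed order
        return 1.0 if (x_min <= x <= x_max and y_min <= y <= y_max) else 0.0
    if action_type == "scroll":
        direction = pred.get("direction")
        return 1.0 if direction is not None and direction == truth.get("direction") else 0.0
    if action_type in TEXT_FIELDS:
        field = TEXT_FIELDS[action_type]
        return 1.0 if field in pred and pred[field] == truth.get(field) else 0.0
    if action_type in TYPE_ONLY_ACTIONS:
        return 1.0
    return 0.0  # unknown action types earn nothing


def total_reward(format_ok: bool, pred: Optional[dict], truth: dict,
                 bbox: Optional[Sequence[float]] = None) -> float:
    """R = R_f + R_a, so each response scores 0, 1, or 2."""
    r_f = 1.0 if format_ok else 0.0
    r_a = action_reward(pred, truth, bbox) if pred is not None else 0.0
    return r_f + r_a
```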
## Training Configuration
| Hyper-parameter | Value |
|----------------|-------|
| Learning rate | 1e-6 (decayed to 0) |
| Temperature | 0.9 |
| Num generations (G) | 8 |
| Num epochs | 4 |
| Max prompt length | 1024 |
| Max completion length | 256 |
| Per-device batch size | 2 |
| Gradient accumulation | 8 |
| Epsilon (clipping) | 0.2 |
| Beta (KL coefficient) | 0.04 |
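For context on the `Num generations (G)` setting, GRPO scores each completion relative to the other completions sampled for the same prompt. The sketch below shows the commonly used group normalization, not code taken from this repository; the use of the population standard deviation and the `eps` value are assumptions, and implementations differ on both.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Normalize the rewards of one group of G completions to zero mean, unit scale."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; some trainers use the sample std instead
    # eps keeps the division finite when every completion gets the same reward.
    return [(r - mu) / (sigma + eps) for r in rewards]


# Rewards for G = 8 completions, each in {0, 1, 2}.
mixed = group_relative_advantages([2.0, 2.0, 1.0, 1.0, 2.0, 0.0, 1.0, 2.0])
uniform = group_relative_advantages([2.0] * 8)
```

When all 8 completions receive the same reward, every advantage is zero and the group contributes no policy-gradient signal.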
## Benchmark Results (from paper)
### GUI Task Automation
| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
|-------|-----------|------------|----------------|
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
| **GUI-Shift-Qwen (k=1)** | **90.6** ↑6.8 | **70.4** ↑11.2 | **54.8** ↑9.9 |
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
| **GUI-Shift-Intern (k=4)** | 88.0 | 56.6 | 23.3 |
### GUI Grounding
| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
|-------|-------------------|-------------------|
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
| **GUI-Shift-Qwen** | **86.6** | **27.9** |
## Data Filtering
The filtering pipeline selects informative and challenging samples:
1. Generate N=8 candidate actions for each state pair using the base model
2. Score each with the rule-based reward
3. Compute a per-sample diversity score: the fraction of the 8 responses that are correct
4. **Discard** samples where all 8 responses are either entirely correct or entirely incorrect (too easy/hard)
5. **Keep** samples with mixed correctness (informative for learning)
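The keep/discard rule in steps 4 and 5 reduces to a check on the number of correct candidates. In this minimal sketch the function names and the input layout (a mapping from sample id to per-candidate correctness flags) are hypothetical, and "correct" is assumed to mean that the candidate earned the action reward.

```python
from typing import Dict, List


def is_informative(correct: List[bool]) -> bool:
    """True when the candidates are neither all correct nor all incorrect."""
    n_correct = sum(correct)
    return 0 < n_correct < len(correct)


def filter_samples(scored: Dict[str, List[bool]]) -> List[str]:
    """Return the ids of samples with mixed correctness, preserving input order."""
    return [sample_id for sample_id, flags in scored.items() if is_informative(flags)]


scored = {
    "too_easy": [True] * 8,                      # all 8 correct: discarded
    "too_hard": [False] * 8,                     # all 8 incorrect: discarded
    "mixed": [True, False, True, True, False, False, True, False],
}
print(filter_samples(scored))  # ['mixed']
```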
## Citation
```bibtex
@article{gao2025guishift,
title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
journal={arXiv preprint arXiv:2505.12493},
year={2025}
}
```
## License
Apache-2.0
<!-- ml-intern-provenance -->
## Generated by ML Intern
This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
## Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "luanns/gui-shift"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```
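`AutoModelForCausalLM` fits text-only causal checkpoints. For the vision-language backbones listed above (for example Qwen2.5-VL), a checkpoint would typically be loaded with `AutoModelForImageTextToText`, with `AutoProcessor` in place of `AutoTokenizer` so that screenshots are preprocessed alongside the text.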