---
tags:
- ml-intern
---
# GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning

This repository implements **GUI-Shift** from the paper ["GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning"](https://arxiv.org/abs/2505.12493) by Longxi Gao, Li Zhang, and Mengwei Xu.

## Overview

GUI-Shift is a self-supervised reinforcement learning framework that enhances Vision-Language Models (VLMs) for GUI agents without relying on costly textual annotations. The core idea is the **K-step GUI Transition** task: given two GUI screenshots (current state S_t and future state S_{t+k}), the model predicts the initial action that caused the transition.

## Key Features

- **K-step GUI Transition**: Self-supervised inverse dynamics learning from GUI trajectory pairs
- **GRPO Training**: Group Relative Policy Optimization for efficient RL training
- **Rule-based Reward Function**: Format reward + Action reward tailored for GUI tasks
- **Data Filtering Pipeline**: Automatic filtering for high-quality training samples
- **Multi-model Support**: Qwen2.5-VL, InternVL3, MimoVL backbones
- **Benchmark Evaluation**: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro

## Architecture

```
gui-shift/
├── src/
│   ├── data_construction/       # K-step GUI Transition data builder
│   ├── training/                # GRPO training with GUI rewards
│   ├── filtering/               # Data filtering pipeline
│   └── evaluation/              # Benchmark evaluation scripts
├── scripts/
│   ├── build_data.sh            # Construct K-step transition data
│   ├── train.sh                 # Launch GRPO training
│   ├── filter_data.sh           # Run data filtering
│   └── evaluate.sh              # Evaluate on benchmarks
├── configs/
│   ├── grpo_config.yaml         # GRPO hyperparameters
│   └── data_config.yaml         # Data construction settings
└── data/
    └── (AndroidControl trajectories)
```

## Quick Start

### 1. Setup Environment

```bash
conda create -n gui-shift python=3.10
conda activate gui-shift
pip install -r requirements.txt
```

### 2. Prepare Data

Download the AndroidControl dataset and construct the K-step GUI Transition data:

```bash
bash scripts/build_data.sh \
  --input_dir /path/to/androidcontrol \
  --output_dir ./data/gui_transition \
  --k_values 1 2 3 4 \
  --samples_per_k 2000
```

### 3. Train with GUI-Shift

The command below reads from `./data/gui_transition/filtered`, which is presumably the output of `scripts/filter_data.sh` (see Data Filtering), so run the filtering step on the constructed data first.

```bash
bash scripts/train.sh \
  --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
  --data_dir ./data/gui_transition/filtered \
  --output_dir ./checkpoints/gui-shift-qwen \
  --k 1
```

### 4. Evaluate

```bash
bash scripts/evaluate.sh \
  --model_path ./checkpoints/gui-shift-qwen \
  --benchmark androidcontrol_low
```

## K-step GUI Transition

The core self-supervised task:

1. Extract state pairs (S_t, S_{t+k}) from GUI trajectories
2. The model sees both screenshots and predicts the action a_t that transitions S_t → S_{t+1}
3. No textual instructions needed: the future state S_{t+k} serves as the visual goal
4. For k > 1, the model must infer temporal dynamics across multiple steps
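
The pairing in steps 1 and 2 can be sketched as follows. This is a minimal illustration, not the repository's data builder: `TransitionSample` and `build_k_step_pairs` are hypothetical names, and the sketch assumes a trajectory is stored as a list of screenshot paths plus a list of actions where `actions[t]` leads from `screenshots[t]` to `screenshots[t + 1]`.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class TransitionSample:
    """One K-step GUI Transition sample (hypothetical container)."""
    image_t: str             # screenshot of S_t
    image_t_plus_k: str      # screenshot of S_{t+k}
    action: dict[str, Any]   # a_t, the first action taken from S_t
    k: int


def build_k_step_pairs(
    screenshots: list[str], actions: list[dict[str, Any]], k: int
) -> list[TransitionSample]:
    """Pair every state S_t with S_{t+k} and label the pair with a_t.

    Assumes actions[t] transitions screenshots[t] to screenshots[t + 1],
    so a trajectory has exactly one more screenshot than actions.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if len(screenshots) != len(actions) + 1:
        raise ValueError("expected len(screenshots) == len(actions) + 1")

    samples = []
    # The last valid start index is the one whose S_{t+k} still exists.
    for t in range(len(screenshots) - k):
        samples.append(
            TransitionSample(screenshots[t], screenshots[t + k], actions[t], k)
        )
    return samples
```

Note that the label is always `a_t`, regardless of `k`: larger `k` only moves the goal screenshot further into the future.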

### Data Format

```json
{
  "id": "episode_001_step_005_k1",
  "image": ["screenshot_t.png", "screenshot_t+k.png"],
  "conversations": [
    {"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
    {"from": "gpt", "value": "<answer>{\"action_type\": \"click\", \"x\": 320, \"y\": 480}</answer>"}
  ],
  "ground_truth_bbox": [300, 460, 340, 500],
  "k": 1
}
```
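
As a rough illustration of how one such record could be assembled, the sketch below serializes a single sample into the format above. `make_record` is a hypothetical helper, and the zero-padded `id` scheme is inferred from the one example shown here rather than taken from the repository.

```python
import json
from typing import Any, Optional, Sequence

# Prompt text copied from the example record above.
PROMPT = "<image><image>What action transitions the first screen to the second screen?"


def make_record(
    episode: int,
    step: int,
    k: int,
    image_t: str,
    image_t_plus_k: str,
    action: dict[str, Any],
    bbox: Optional[Sequence[int]] = None,
) -> dict[str, Any]:
    """Build one training record in the conversation format shown above."""
    record: dict[str, Any] = {
        # Assumed id scheme: three-digit episode and step indices.
        "id": f"episode_{episode:03d}_step_{step:03d}_k{k}",
        "image": [image_t, image_t_plus_k],
        "conversations": [
            {"from": "human", "value": PROMPT},
            {"from": "gpt", "value": f"<answer>{json.dumps(action)}</answer>"},
        ],
        "k": k,
    }
    # Only coordinate-based actions carry a ground-truth bounding box.
    if bbox is not None:
        record["ground_truth_bbox"] = list(bbox)
    return record
```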

## Reward Design

### Format Reward (R_f)
- Enforces `<answer>...</answer>` tags in output
- R_f = 1 if format correct, 0 otherwise

### Action Reward (R_a)
- **click / long_press**: Reward = 1 if predicted point falls within ground-truth bounding box
- **scroll**: Reward = 1 if predicted direction matches ground truth
- **open_app / input_text**: Reward = 1 if predicted string matches exactly
- **navigate_back / navigate_home / wait**: Reward = 1 if action type matches

Total reward: R = R_f + R_a
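
A minimal sketch of these rules is shown below, assuming the answer payload is a JSON object as in the Data Format example. The field names `direction`, `app_name`, and `text` are assumptions made for illustration, as are all function names; the repository's reward code may differ.

```python
import json
import re
from typing import Any, Optional, Sequence

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
TYPE_ONLY_ACTIONS = ("navigate_back", "navigate_home", "wait")


def format_reward(completion: str) -> float:
    """R_f: 1 if the completion contains an <answer>...</answer> span."""
    return 1.0 if ANSWER_RE.search(completion) else 0.0


def parse_action(completion: str) -> Optional[dict[str, Any]]:
    """Extract the JSON action from the answer tags, or None if unparsable."""
    match = ANSWER_RE.search(completion)
    if match is None:
        return None
    try:
        action = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    return action if isinstance(action, dict) else None


def action_reward(
    pred: Optional[dict[str, Any]],
    truth: dict[str, Any],
    bbox: Optional[Sequence[float]] = None,
) -> float:
    """R_a: rule-based check of the predicted action against ground truth."""
    if pred is None or pred.get("action_type") != truth.get("action_type"):
        return 0.0
    kind = truth["action_type"]
    if kind in ("click", "long_press"):
        x, y = pred.get("x"), pred.get("y")
        if bbox is None or not all(isinstance(v, (int, float)) for v in (x, y)):
            return 0.0
        x1, y1, x2, y2 = bbox
        return 1.0 if x1 <= x <= x2 and y1 <= y <= y2 else 0.0
    if kind == "scroll":
        # "direction" is an assumed field name.
        return 1.0 if pred.get("direction") == truth.get("direction") else 0.0
    if kind in ("open_app", "input_text"):
        # "app_name" and "text" are assumed field names.
        key = "app_name" if kind == "open_app" else "text"
        return 1.0 if key in pred and pred[key] == truth.get(key) else 0.0
    # For the remaining known actions, a matching type is sufficient.
    return 1.0 if kind in TYPE_ONLY_ACTIONS else 0.0


def total_reward(completion, truth, bbox=None) -> float:
    """R = R_f + R_a."""
    return format_reward(completion) + action_reward(parse_action(completion), truth, bbox)
```

With the example record from the Data Format section, a prediction of `(320, 480)` lies inside the box `[300, 460, 340, 500]`, so both terms fire and the total is 2.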

## Training Configuration

| Hyper-parameter | Value |
|----------------|-------|
| Learning rate | 1e-6 (warmup to 0) |
| Temperature | 0.9 |
| Num generations (G) | 8 |
| Num epochs | 4 |
| Max prompt length | 1024 |
| Max completion length | 256 |
| Per-device batch size | 2 |
| Gradient accumulation | 8 |
| Epsilon (clipping) | 0.2 |
| Beta (KL coefficient) | 0.04 |
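
To make the role of `Num generations (G)` concrete: GRPO samples G completions per prompt and scores each one relative to the rest of its group. The sketch below shows the commonly used mean/std normalization. It is an illustration under that assumption, not the repository's trainer code, and `stability_eps` is a hypothetical parameter unrelated to the clipping epsilon in the table.

```python
import statistics


def group_relative_advantages(
    rewards: list[float], stability_eps: float = 1e-4
) -> list[float]:
    """Normalize the rewards of one prompt's G completions within the group.

    stability_eps only guards against division by zero when every
    completion in the group received the same reward.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + stability_eps) for r in rewards]
```

When all G completions receive the same reward, every advantage is zero and the sample contributes no learning signal, which is the motivation for discarding all-correct and all-incorrect samples in the Data Filtering section.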

## Benchmark Results (from paper)

### GUI Task Automation

| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
|-------|-----------|------------|----------------|
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
| **GUI-Shift-Qwen (k=1)** | **90.6** ↑6.8 | **70.4** ↑11.2 | **54.8** ↑9.9 |
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
| **GUI-Shift-Intern (k=4)** | 88.0 | 56.6 | 23.3 |

### GUI Grounding

| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
|-------|-------------------|-------------------|
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
| **GUI-Shift-Qwen** | **86.6** | **27.9** |

## Data Filtering

The filtering pipeline selects informative and challenging samples:

1. Generate N=8 candidate actions for each state pair using the base model
2. Score each with the rule-based reward
3. Compute a per-sample diversity score from the proportion of correct versus incorrect responses among the 8 candidates
4. **Discard** samples where all 8 responses are either entirely correct or entirely incorrect (too easy/hard)
5. **Keep** samples with mixed correctness (informative for learning)
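
The keep/discard rule in steps 4 and 5 can be sketched as below. The sketch assumes each sample has already been scored, so the input is a mapping from sample id to one correctness flag per generated response; `correct_fraction` and `filter_informative` are hypothetical names.

```python
def correct_fraction(correct_flags: list[bool]) -> float:
    """Fraction of a sample's candidate responses judged correct."""
    return sum(correct_flags) / len(correct_flags)


def filter_informative(scored: dict[str, list[bool]]) -> list[str]:
    """Return ids of samples whose responses are neither all right nor all wrong."""
    kept = []
    for sample_id, flags in scored.items():
        if not flags:
            continue  # no responses were generated for this sample
        # 0.0 means too hard (all incorrect); 1.0 means too easy (all correct).
        if 0.0 < correct_fraction(flags) < 1.0:
            kept.append(sample_id)
    return kept
```

For example, a sample with 3 correct responses out of 8 has a correct fraction of 0.375 and is kept, while samples at 0/8 or 8/8 are dropped.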

## Citation

```bibtex
@article{gao2025guishift,
  title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
  author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
  journal={arXiv preprint arXiv:2505.12493},
  year={2025}
}
```

## License

Apache-2.0

<!-- ml-intern-provenance -->
## Generated by ML Intern

This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.

- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "luanns/gui-shift"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```

For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.