Upload README.md
Browse files
README.md
CHANGED
|
@@ -1,26 +1,173 @@
|
|
| 1 |
-
---
|
| 2 |
-
tags:
|
| 3 |
-
- ml-intern
|
| 4 |
-
---
|
| 5 |
|
| 6 |
-
|
| 7 |
|
| 8 |
-
|
| 9 |
-
## Generated by ML Intern
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
-
|
| 14 |
-
- Source code: https://github.com/huggingface/ml-intern
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 20 |
|
| 21 |
-
model_id = "luanns/gui-shift"
|
| 22 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 23 |
-
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 24 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
|
|
|
| 1 |
+
# GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
This repository implements **GUI-Shift** from the paper ["GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning"](https://arxiv.org/abs/2505.12493) by Longxi Gao, Li Zhang, and Mengwei Xu.
|
| 4 |
|
| 5 |
+
## Overview
|
|
|
|
| 6 |
|
| 7 |
+
GUI-Shift is a self-supervised reinforcement learning framework that enhances Vision-Language Models (VLMs) for GUI agents without relying on costly textual annotations. The core idea is the **K-step GUI Transition** task: given two GUI screenshots (current state S_t and future state S_{t+k}), the model predicts the initial action that caused the transition.
|
| 8 |
|
| 9 |
+
## Key Features
|
|
|
|
| 10 |
|
| 11 |
+
- **K-step GUI Transition**: Self-supervised inverse dynamics learning from GUI trajectory pairs
|
| 12 |
+
- **GRPO Training**: Group Relative Policy Optimization for efficient RL training
|
| 13 |
+
- **Rule-based Reward Function**: Format reward + Action reward tailored for GUI tasks
|
| 14 |
+
- **Data Filtering Pipeline**: Automatic filtering for high-quality training samples
|
| 15 |
+
- **Multi-model Support**: Qwen2.5-VL, InternVL3, MimoVL backbones
|
| 16 |
+
- **Benchmark Evaluation**: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro
|
| 17 |
|
| 18 |
+
## Architecture
|
|
|
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
```
|
| 21 |
+
gui-shift/
|
| 22 |
+
βββ src/
|
| 23 |
+
β βββ data_construction/ # K-step GUI Transition data builder
|
| 24 |
+
β βββ training/ # GRPO training with GUI rewards
|
| 25 |
+
β βββ filtering/ # Data filtering pipeline
|
| 26 |
+
β βββ evaluation/ # Benchmark evaluation scripts
|
| 27 |
+
βββ scripts/
|
| 28 |
+
β βββ build_data.sh # Construct K-step transition data
|
| 29 |
+
β βββ train.sh # Launch GRPO training
|
| 30 |
+
β βββ filter_data.sh # Run data filtering
|
| 31 |
+
β βββ evaluate.sh # Evaluate on benchmarks
|
| 32 |
+
βββ configs/
|
| 33 |
+
β βββ grpo_config.yaml # GRPO hyperparameters
|
| 34 |
+
β βββ data_config.yaml # Data construction settings
|
| 35 |
+
βββ data/
|
| 36 |
+
βββ (AndroidControl trajectories)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Quick Start
|
| 40 |
+
|
| 41 |
+
### 1. Setup Environment
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
conda create -n gui-shift python=3.10
|
| 45 |
+
conda activate gui-shift
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### 2. Prepare Data
|
| 50 |
+
|
| 51 |
+
Download AndroidControl dataset and construct K-step GUI Transition data:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
bash scripts/build_data.sh \
|
| 55 |
+
--input_dir /path/to/androidcontrol \
|
| 56 |
+
--output_dir ./data/gui_transition \
|
| 57 |
+
--k_values 1 2 3 4 \
|
| 58 |
+
--samples_per_k 2000
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### 3. Train with GUI-Shift
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
bash scripts/train.sh \
|
| 65 |
+
--model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
|
| 66 |
+
--data_dir ./data/gui_transition/filtered \
|
| 67 |
+
--output_dir ./checkpoints/gui-shift-qwen \
|
| 68 |
+
--k 1
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### 4. Evaluate
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
bash scripts/evaluate.sh \
|
| 75 |
+
--model_path ./checkpoints/gui-shift-qwen \
|
| 76 |
+
--benchmark androidcontrol_low
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## K-step GUI Transition
|
| 80 |
+
|
| 81 |
+
The core self-supervised task:
|
| 82 |
+
|
| 83 |
+
1. Extract state pairs (S_t, S_{t+k}) from GUI trajectories
|
| 84 |
+
2. The model sees both screenshots and predicts the action a_t that transitions S_t β S_{t+1}
|
| 85 |
+
3. No textual instructions needed β the future state S_{t+k} serves as the visual goal
|
| 86 |
+
4. For k > 1, the model must infer temporal dynamics across multiple steps
|
| 87 |
+
|
| 88 |
+
### Data Format
|
| 89 |
+
|
| 90 |
+
```json
|
| 91 |
+
{
|
| 92 |
+
"id": "episode_001_step_005_k1",
|
| 93 |
+
"image": ["screenshot_t.png", "screenshot_t+k.png"],
|
| 94 |
+
"conversations": [
|
| 95 |
+
{"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
|
| 96 |
+
{"from": "gpt", "value": "<answer>{\"action_type\": \"click\", \"x\": 320, \"y\": 480}</answer>"}
|
| 97 |
+
],
|
| 98 |
+
"ground_truth_bbox": [300, 460, 340, 500],
|
| 99 |
+
"k": 1
|
| 100 |
+
}
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
## Reward Design
|
| 104 |
+
|
| 105 |
+
### Format Reward (R_f)
|
| 106 |
+
- Enforces `<answer>...</answer>` tags in output
|
| 107 |
+
- R_f = 1 if format correct, 0 otherwise
|
| 108 |
+
|
| 109 |
+
### Action Reward (R_a)
|
| 110 |
+
- **click / long_press**: Reward = 1 if predicted point falls within ground-truth bounding box
|
| 111 |
+
- **scroll**: Reward = 1 if predicted direction matches ground truth
|
| 112 |
+
- **open_app / input_text**: Reward = 1 if predicted string matches exactly
|
| 113 |
+
- **navigate_back / navigate_home / wait**: Reward = 1 if action type matches
|
| 114 |
+
|
| 115 |
+
Total reward: R = R_f + R_a
|
| 116 |
+
|
| 117 |
+
## Training Configuration
|
| 118 |
+
|
| 119 |
+
| Hyper-parameter | Value |
|
| 120 |
+
|----------------|-------|
|
| 121 |
+
| Learning rate | 1e-6 (warmup to 0) |
|
| 122 |
+
| Temperature | 0.9 |
|
| 123 |
+
| Num generations (G) | 8 |
|
| 124 |
+
| Num epochs | 4 |
|
| 125 |
+
| Max prompt length | 1024 |
|
| 126 |
+
| Max completion length | 256 |
|
| 127 |
+
| Per-device batch size | 2 |
|
| 128 |
+
| Gradient accumulation | 8 |
|
| 129 |
+
| Epsilon (clipping) | 0.2 |
|
| 130 |
+
| Beta (KL coefficient) | 0.04 |
|
| 131 |
+
|
| 132 |
+
## Benchmark Results (from paper)
|
| 133 |
+
|
| 134 |
+
### GUI Task Automation
|
| 135 |
+
|
| 136 |
+
| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
|
| 137 |
+
|-------|-----------|------------|----------------|
|
| 138 |
+
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
|
| 139 |
+
| **GUI-Shift-Qwen (k=1)** | **90.6** β6.8 | **70.4** β11.2 | **54.8** β9.9 |
|
| 140 |
+
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
|
| 141 |
+
| **GUI-Shift-Intern (k=4)** | 88.0 | 56.6 | 23.3 |
|
| 142 |
+
|
| 143 |
+
### GUI Grounding
|
| 144 |
+
|
| 145 |
+
| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
|
| 146 |
+
|-------|-------------------|-------------------|
|
| 147 |
+
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
|
| 148 |
+
| **GUI-Shift-Qwen** | **86.6** | **27.9** |
|
| 149 |
+
|
| 150 |
+
## Data Filtering
|
| 151 |
+
|
| 152 |
+
The filtering pipeline selects informative and challenging samples:
|
| 153 |
+
|
| 154 |
+
1. Generate N=8 candidate actions for each state pair using the base model
|
| 155 |
+
2. Score each with the rule-based reward
|
| 156 |
+
3. Compute sample diversity score: proportion of all-correct vs all-incorrect responses
|
| 157 |
+
4. **Discard** samples where all 8 responses are either entirely correct or entirely incorrect (too easy/hard)
|
| 158 |
+
5. **Keep** samples with mixed correctness (informative for learning)
|
| 159 |
+
|
| 160 |
+
## Citation
|
| 161 |
+
|
| 162 |
+
```bibtex
|
| 163 |
+
@article{gao2025guishift,
|
| 164 |
+
title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
|
| 165 |
+
author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
|
| 166 |
+
journal={arXiv preprint arXiv:2505.12493},
|
| 167 |
+
year={2025}
|
| 168 |
+
}
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
## License
|
| 172 |
|
| 173 |
+
Apache-2.0
|