---
tags:
- ml-intern
---
# GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning

This repository implements **GUI-Shift** from the paper ["GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning"](https://arxiv.org/abs/2505.12493) by Longxi Gao, Li Zhang, and Mengwei Xu.

## Overview

GUI-Shift is a self-supervised reinforcement learning framework that enhances Vision-Language Models (VLMs) for GUI agents without relying on costly textual annotations. The core idea is the **K-step GUI Transition** task: given two GUI screenshots (current state S_t and future state S_{t+k}), the model predicts the initial action that caused the transition.

## Key Features

- **K-step GUI Transition**: Self-supervised inverse dynamics learning from GUI trajectory pairs
- **GRPO Training**: Group Relative Policy Optimization for efficient RL training
- **Rule-based Reward Function**: Format reward + Action reward tailored for GUI tasks
- **Data Filtering Pipeline**: Automatic selection of informative samples whose rollouts have mixed correctness
- **Multi-model Support**: Qwen2.5-VL, InternVL3, and MiMo-VL backbones
- **Benchmark Evaluation**: AndroidControl, GUI Odyssey, ScreenSpot-v2, ScreenSpot-Pro

## Architecture

```
gui-shift/
├── src/
│   ├── data_construction/   # K-step GUI Transition data builder
│   ├── training/            # GRPO training with GUI rewards
│   ├── filtering/           # Data filtering pipeline
│   └── evaluation/          # Benchmark evaluation scripts
├── scripts/
│   ├── build_data.sh        # Construct K-step transition data
│   ├── train.sh             # Launch GRPO training
│   ├── filter_data.sh       # Run data filtering
│   └── evaluate.sh          # Evaluate on benchmarks
├── configs/
│   ├── grpo_config.yaml     # GRPO hyperparameters
│   └── data_config.yaml     # Data construction settings
└── data/
    └── (AndroidControl trajectories)
```

## Quick Start

### 1. Setup Environment

```bash
conda create -n gui-shift python=3.10
conda activate gui-shift
pip install -r requirements.txt
```

### 2. Prepare Data

Download the AndroidControl dataset and construct the K-step GUI Transition data:

```bash
bash scripts/build_data.sh \
  --input_dir /path/to/androidcontrol \
  --output_dir ./data/gui_transition \
  --k_values 1 2 3 4 \
  --samples_per_k 2000
```

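### 3. Train with GUI-Shift

The `--data_dir` below points at a `filtered` subdirectory, which is expected to hold the output of the data filtering pipeline (`scripts/filter_data.sh`, described under Data Filtering) run on the constructed data.

```bash
bash scripts/train.sh \
  --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
  --data_dir ./data/gui_transition/filtered \
  --output_dir ./checkpoints/gui-shift-qwen \
  --k 1
```
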
### 4. Evaluate

```bash
bash scripts/evaluate.sh \
  --model_path ./checkpoints/gui-shift-qwen \
  --benchmark androidcontrol_low
```

## K-step GUI Transition

The core self-supervised task:

1. Extract state pairs (S_t, S_{t+k}) from GUI trajectories
2. The model sees both screenshots and predicts the first action a_t on the path toward S_{t+k}, i.e. the action that transitions S_t → S_{t+1}
3. No textual instructions are needed: the future state S_{t+k} serves as the visual goal
4. For k > 1, the intermediate screens are not shown, so the model must infer temporal dynamics across multiple steps

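A minimal sketch of step 1 follows. It assumes a trajectory is stored as a list of screenshot paths plus the list of actions executed between them, so `screenshots[i]` is the screen shown before `actions[i]` runs. `TransitionSample` and `extract_transition_pairs` are hypothetical names for illustration and are not the API of `src/data_construction/`.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class TransitionSample:
    """Hypothetical container for one K-step GUI Transition sample."""
    state_t: str              # screenshot of S_t
    state_t_plus_k: str       # screenshot of S_{t+k}, the visual goal
    action_t: Dict[str, Any]  # label: the first action a_t (S_t -> S_{t+1})
    k: int


def extract_transition_pairs(
    screenshots: List[str], actions: List[Dict[str, Any]], k: int
) -> List[TransitionSample]:
    """Pair every S_t with S_{t+k} and label the pair with a_t.

    Assumes screenshots[i] is the screen shown before actions[i] runs,
    so a trajectory with n actions has n + 1 screenshots.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if len(screenshots) != len(actions) + 1:
        raise ValueError("expected one more screenshot than actions")
    samples = []
    # t + k must still index a screenshot, so t stops at len(screenshots) - k - 1
    for t in range(len(screenshots) - k):
        samples.append(TransitionSample(screenshots[t], screenshots[t + k], actions[t], k))
    return samples


shots = ["s0.png", "s1.png", "s2.png", "s3.png"]
acts = [{"action_type": "click"}, {"action_type": "scroll"}, {"action_type": "wait"}]
pairs = extract_transition_pairs(shots, acts, k=2)
print([(p.state_t, p.state_t_plus_k, p.action_t["action_type"]) for p in pairs])
# [('s0.png', 's2.png', 'click'), ('s1.png', 's3.png', 'scroll')]
```

Only the first action of each k-step window becomes the label. A trajectory with n actions therefore yields n - k + 1 samples for a given k, and larger k values produce fewer but harder pairs.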
### Data Format

```json
{
  "id": "episode_001_step_005_k1",
  "image": ["screenshot_t.png", "screenshot_t+k.png"],
  "conversations": [
    {"from": "human", "value": "<image><image>What action transitions the first screen to the second screen?"},
    {"from": "gpt", "value": "<answer>{\"action_type\": \"click\", \"x\": 320, \"y\": 480}</answer>"}
  ],
  "ground_truth_bbox": [300, 460, 340, 500],
  "k": 1
}
```

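A record like the one above could be assembled with a small helper. This sketch copies the `id` pattern, prompt string, and `<answer>` JSON payload from the example. `build_record` is hypothetical, and reading `ground_truth_bbox` as `[x_min, y_min, x_max, y_max]` is an assumption.

```python
import json
from typing import Any, Dict, List

# Prompt string copied from the data format example.
PROMPT = "<image><image>What action transitions the first screen to the second screen?"


def build_record(
    episode: int,
    step: int,
    k: int,
    image_t: str,
    image_t_plus_k: str,
    action: Dict[str, Any],
    bbox: List[int],
) -> Dict[str, Any]:
    """Build one training record in the example's format (hypothetical helper)."""
    # json.dumps with default separators yields '{"action_type": "click", ...}',
    # which matches the escaped payload inside the example's <answer> tag.
    answer = json.dumps(action)
    return {
        "id": f"episode_{episode:03d}_step_{step:03d}_k{k}",
        "image": [image_t, image_t_plus_k],
        "conversations": [
            {"from": "human", "value": PROMPT},
            {"from": "gpt", "value": f"<answer>{answer}</answer>"},
        ],
        "ground_truth_bbox": bbox,  # assumed [x_min, y_min, x_max, y_max]
        "k": k,
    }


record = build_record(1, 5, 1, "screenshot_t.png", "screenshot_t+k.png",
                      {"action_type": "click", "x": 320, "y": 480}, [300, 460, 340, 500])
print(json.dumps(record, indent=2))
```

The answer point (320, 480) lies inside the box, so under the click rule in Reward Design this reference answer would earn the full action reward.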
## Reward Design

### Format Reward (R_f)

- Enforces `<answer>...</answer>` tags in the output
- R_f = 1 if the format is correct, 0 otherwise

### Action Reward (R_a)

Each rule yields R_a = 1 when satisfied and 0 otherwise:

- **click / long_press**: the predicted point falls within the ground-truth bounding box
- **scroll**: the predicted direction matches the ground truth
- **open_app / input_text**: the predicted string matches exactly
- **navigate_back / navigate_home / wait**: the action type matches

Total reward: R = R_f + R_a, so each completion scores 0, 1, or 2.

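Below is a plain-Python sketch of both rewards and their sum. Several details are assumptions, because only the click payload appears in the data format example:

- The JSON keys for scroll (`direction`), open_app (`app_name`), and input_text (`text`) are assumed.
- An R_a of 0 for any mismatched `action_type` is a stricter reading chosen for illustration.
- R_f requires the whole completion to be a single `<answer>...</answer>` span.

```python
import json
import re
from typing import Any, Dict, Optional, Sequence

ANSWER_RE = re.compile(r"<answer>(.+?)</answer>", re.DOTALL)

POINT_ACTIONS = {"click", "long_press"}
# Hypothetical payload keys for string-valued actions.
STRING_KEYS = {"open_app": "app_name", "input_text": "text"}
TYPE_ONLY_ACTIONS = {"navigate_back", "navigate_home", "wait"}


def format_reward(completion: str) -> float:
    """R_f: 1.0 if the whole completion is one <answer>...</answer> span, else 0.0."""
    return 1.0 if ANSWER_RE.fullmatch(completion.strip()) else 0.0


def parse_action(completion: str) -> Optional[Dict[str, Any]]:
    """Return the JSON action inside the first <answer> tag, or None if unparsable."""
    match = ANSWER_RE.search(completion)
    if match is None:
        return None
    try:
        action = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    return action if isinstance(action, dict) else None


def action_reward(pred: Optional[Dict[str, Any]], gt: Dict[str, Any],
                  gt_bbox: Optional[Sequence[float]] = None) -> float:
    """R_a: rule-based check of the predicted action against the ground truth."""
    if pred is None or pred.get("action_type") != gt["action_type"]:
        return 0.0  # assumption: a wrong action type earns nothing
    kind = gt["action_type"]
    if kind in POINT_ACTIONS:
        if gt_bbox is None:
            return 0.0
        x, y = pred.get("x"), pred.get("y")
        if not isinstance(x, (int, float)) or not isinstance(y, (int, float)):
            return 0.0
        x_min, y_min, x_max, y_max = gt_bbox  # assumed box layout
        return 1.0 if x_min <= x <= x_max and y_min <= y <= y_max else 0.0
    if kind == "scroll":
        return 1.0 if pred.get("direction") == gt.get("direction") else 0.0
    if kind in STRING_KEYS:
        key = STRING_KEYS[kind]
        return 1.0 if pred.get(key) == gt.get(key) else 0.0
    if kind in TYPE_ONLY_ACTIONS:
        return 1.0  # matching action_type is all that is required
    return 0.0  # action types not covered by the rules earn nothing


def total_reward(completion: str, gt: Dict[str, Any],
                 gt_bbox: Optional[Sequence[float]] = None) -> float:
    """R = R_f + R_a, so every completion scores 0, 1, or 2."""
    return format_reward(completion) + action_reward(parse_action(completion), gt, gt_bbox)
```

Take the example record in Data Format. A click at (330, 470) wrapped in a single `<answer>` tag earns R = 2. The same answer preceded by free text earns R = 1, because R_f drops to 0 while the action is still parsed and scored.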
## Training Configuration

| Hyper-parameter | Value |
|----------------|-------|
| Learning rate | 1e-6 (warmup to 0) |
| Temperature | 0.9 |
| Num generations (G) | 8 |
| Num epochs | 4 |
| Max prompt length | 1024 |
| Max completion length | 256 |
| Per-device batch size | 2 |
| Gradient accumulation | 8 |
| Epsilon (clipping) | 0.2 |
| Beta (KL coefficient) | 0.04 |

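The per-token sketch below shows where G, epsilon, and beta from the table enter the update. It follows the GRPO objective from the DeepSeekMath paper: group-normalized advantages, a PPO-style clipped ratio, and the KL estimator π_ref/π_θ − log(π_ref/π_θ) − 1. It is an illustration only, not the repository's training code. The use of the sample standard deviation and the 1e-4 stabilizer are assumptions.

```python
import math
import statistics
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """A_i = (r_i - mean(r)) / (std(r) + eps) over the G completions of one prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # sample std; the exact choice is an assumption
    return [(r - mean) / (std + eps) for r in rewards]


def grpo_token_loss(logp_new: float, logp_old: float, logp_ref: float,
                    advantage: float, epsilon: float = 0.2, beta: float = 0.04) -> float:
    """Negative clipped surrogate plus KL penalty for a single generated token."""
    ratio = math.exp(logp_new - logp_old)  # pi_theta / pi_theta_old
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # KL(pi_theta || pi_ref) estimate used by GRPO:
    # pi_ref / pi_theta - log(pi_ref / pi_theta) - 1, which is always >= 0.
    log_ref_ratio = logp_ref - logp_new
    kl = math.exp(log_ref_ratio) - log_ref_ratio - 1.0
    return -(surrogate - beta * kl)


# G = 8 completions for one state pair, scored with R = R_f + R_a.
rewards = [2.0, 2.0, 1.0, 0.0, 2.0, 1.0, 1.0, 0.0]
adv = group_relative_advantages(rewards)
print([round(a, 3) for a in adv])
```

A group whose eight completions all earn the same reward produces all-zero advantages and hence no policy-gradient signal. This is consistent with the Data Filtering step discarding all-correct and all-incorrect samples.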
## Benchmark Results (from paper)

### GUI Task Automation

| Model | AC-Low EM | AC-High EM | GUI Odyssey EM |
|-------|-----------|------------|----------------|
| Qwen2.5-VL-7B (base) | 83.8 | 59.2 | 44.9 |
| **GUI-Shift-Qwen (k=1)** | **90.6** ↑6.8 | **70.4** ↑11.2 | **54.8** ↑9.9 |
| InternVL3-8B (base) | 90.0 | 49.8 | 20.3 |
| **GUI-Shift-Intern (k=4)** | 88.0 ↓2.0 | 56.6 ↑6.8 | 23.3 ↑3.0 |

### GUI Grounding

| Model | ScreenSpot-v2 Avg | ScreenSpot-Pro Avg |
|-------|-------------------|-------------------|
| Qwen2.5-VL-7B (base) | 84.1 | 26.4 |
| **GUI-Shift-Qwen** | **86.6** | **27.9** |

## Data Filtering

The filtering pipeline selects informative and challenging samples:

1. Generate N = 8 candidate actions for each state pair using the base model
2. Score each candidate with the rule-based reward
3. Compute each sample's accuracy: the fraction of its 8 candidates that are correct
4. **Discard** samples where all 8 responses are correct (too easy) or all 8 are incorrect (too hard)
5. **Keep** samples with mixed correctness, which are informative for learning

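The steps above can be sketched as follows. `generate` (sampling N candidates from the base model) and `score` (the rule-based reward) are passed in as hypothetical callables. Treating a candidate as correct when its action reward is positive is an assumption about what "correct" means here.

```python
from typing import Any, Callable, List, Sequence


def filter_transition_samples(
    samples: Sequence[Any],
    generate: Callable[[Any, int], List[Any]],
    score: Callable[[Any, Any], float],
    n: int = 8,
) -> List[Any]:
    """Keep samples whose n candidates are neither all correct nor all incorrect."""
    kept = []
    for sample in samples:
        candidates = generate(sample, n)  # step 1: sample n actions from the base model
        # steps 2-3: score each candidate and count the correct ones
        num_correct = sum(score(sample, c) > 0 for c in candidates)
        if 0 < num_correct < len(candidates):  # steps 4-5: keep mixed outcomes only
            kept.append(sample)
    return kept


# Toy stand-ins: each "sample" maps to a fixed list of per-candidate action rewards.
outcomes = {"easy": [1.0] * 8, "hard": [0.0] * 8, "mixed": [1.0, 0.0, 1.0] + [0.0] * 5}
kept = filter_transition_samples(
    list(outcomes),
    generate=lambda sample, n: outcomes[sample][:n],
    score=lambda sample, reward: reward,
)
print(kept)  # ['mixed']
```

In practice `generate` would batch rollouts through the base VLM, and `score` would apply the action-reward rules from Reward Design to each parsed candidate.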
## Citation

```bibtex
@article{gao2025guishift,
  title={GUI-Shift: Enhancing VLM-based GUI Agents through Self-supervised Reinforcement Learning},
  author={Gao, Longxi and Zhang, Li and Xu, Mengwei},
  journal={arXiv preprint arXiv:2505.12493},
  year={2025}
}
```

## License

Apache-2.0

<!-- ml-intern-provenance -->
## Generated by ML Intern

This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.

- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern

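## Usage

The snippet assumes the repository hosts a GUI-Shift checkpoint fine-tuned from Qwen2.5-VL-7B-Instruct. Because the model takes screenshots as input, it is loaded with a processor rather than a plain tokenizer.

```python
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "luanns/gui-shift"
# The processor handles both the text prompt and the two screenshots.
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
```

For checkpoints built on other backbones, such as InternVL3, use the loading code given in that backbone's model card instead.
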