---
license: apache-2.0
library_name: transformers
tags: []
---

# Introduction

Reinforcement learning (RL), and GRPO in particular, helps with grounding because its objective is directly aligned with the task: it rewards successful clicks rather than encouraging long textual Chain-of-Thought (CoT) reasoning. Unlike approaches that rely heavily on verbose CoT traces, GRPO directly incentivizes actionable, grounded responses. Based on the findings in our [blog](https://huggingface.co/blog/HelloKKMe/grounding-r1), we share state-of-the-art GUI grounding models trained with GRPO.
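
For intuition, a click-based grounding reward can be as simple as a binary check of whether the predicted point lands inside the target element's bounding box. The sketch below is illustrative only; the function name and exact reward shaping are our assumptions, not the released training code (see the blog for details):

```python
def click_reward(pred_xy, target_bbox):
    """Illustrative GRPO-style grounding reward (an assumption, not the exact training code).

    Returns 1.0 when the predicted (x, y) point falls inside the ground-truth
    element's bounding box (x0, y0, x1, y1), else 0.0.
    """
    x, y = pred_xy
    x0, y0, x1, y1 = target_bbox
    return 1.0 if x0 <= x <= x1 and y0 <= y <= y1 else 0.0
```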

# Performance

We follow the standard evaluation protocol and benchmark our model on three challenging datasets: ScreenSpot-V2, ScreenSpot-Pro, and OSWorld-G. Our method consistently achieves the best results among all open-source model families. Below are the comparative results:

| **Model**               | **Size** | **Open Source** | **ScreenSpot-V2** | **ScreenSpot-Pro** | **OSWorld-G** |
|-------------------------|:--------:|:---------------:|:-----------------:|:------------------:|:-------------:|
| OpenAI CUA              | —        | ❌              | 87.9              | 23.4               | —             |
| Claude 3.7              | —        | ❌              | 87.6              | 27.7               | —             |
| JEDI-7B                 | 7B       | ✅              | 91.7              | 39.5               | 54.1          |
| SE-GUI                  | 7B       | ✅              | 90.3              | 47.0               | —             |
| UI-TARS                 | 7B       | ✅              | 91.6              | 35.7               | 47.5          |
| UI-TARS-1.5             | 7B       | ✅              | 89.7*             | 42.0*              | 64.2*         |
| UGround-v1-7B           | 7B       | ✅              | —                 | 31.1               | 36.4          |
| Qwen2.5-VL-32B-Instruct | 32B      | ✅              | 91.9*             | 48.0               | 59.6*         |
| UGround-v1-72B          | 72B      | ✅              | —                 | 34.5               | —             |
| Qwen2.5-VL-72B-Instruct | 72B      | ✅              | 94.0*             | 53.3               | 62.2*         |
| UI-TARS                 | 72B      | ✅              | 90.3              | 38.1               | —             |
| GTA1 (Ours)             | 7B       | ✅              | 92.4 <sub>*(∆ +2.7)*</sub> | 50.1 <sub>*(∆ +8.1)*</sub> | 67.7 <sub>*(∆ +3.5)*</sub> |
| GTA1 (Ours)             | 32B      | ✅              | 93.2 <sub>*(∆ +1.3)*</sub> | 53.6 <sub>*(∆ +5.6)*</sub> | 61.9 <sub>*(∆ +2.3)*</sub> |
| GTA1 (Ours)             | 72B      | ✅              | 94.8 <sub>*(∆ +0.8)*</sub> | 58.4 <sub>*(∆ +5.1)*</sub> | 66.7 <sub>*(∆ +4.5)*</sub> |


> **Note:**  
> - Model size is indicated in billions (B) of parameters.  
> - A dash (—) denotes results that are currently unavailable.  
> - A superscript asterisk (﹡) denotes our evaluated result.
> - UI-TARS-1.5 7B, Qwen2.5-VL-32B-Instruct, and Qwen2.5-VL-72B-Instruct serve as the baselines for our 7B, 32B, and 72B models, respectively.
> - ∆ indicates the performance improvement of our model over its baseline.

# Inference

Below is a code snippet demonstrating how to run inference with a trained model.

```python
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch
import re

SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
'''
SYSTEM_PROMPT = SYSTEM_PROMPT.strip()

# Extract the first "(x, y)" coordinate pair from the model output
def extract_coordinates(raw_string):
    try:
        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
        # Go through float so decimal outputs such as "(12.5, 7.0)" still parse
        return tuple(int(float(v)) for v in matches[0])
    except (IndexError, ValueError):
        # No coordinate pair found; fall back to the origin
        return 0, 0

# Load model and processor
model_path = "HelloKKMe/GTA1-7B"
max_new_tokens = 32

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package; drop this arg to use default attention
    device_map="auto"
)
processor = AutoProcessor.from_pretrained(
    model_path,
    min_pixels=3136,
    max_pixels=4096 * 2160
)

# Load and resize image
image = Image.open("file path")  # replace with the path to your screenshot
instruction = "description"  # natural-language description of the target element
width, height = image.width, image.height

resized_height, resized_width = smart_resize(
    image.height,
    image.width,
    factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
    min_pixels=processor.image_processor.min_pixels,
    max_pixels=processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height

# Prepare system and user messages
system_message = {
    "role": "system",
    "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
}

user_message = {
    "role": "user",
    "content": [
        {"type": "image", "image": resized_image},
        {"type": "text", "text": instruction}
    ]
}

# Tokenize and prepare inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = processor.apply_chat_template([system_message, user_message], tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)

# Generate prediction
output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)  # greedy decoding; temperature is ignored when do_sample=False
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text)
# The model predicts coordinates in the resized image's pixel space;
# scale them back to the original resolution
pred_x *= scale_x
pred_y *= scale_y
print(pred_x, pred_y)
```
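
The model predicts coordinates at the resized resolution stated in the system prompt, which is why the script multiplies by `scale_x`/`scale_y` to recover positions in the original screenshot. As a quick, self-contained check of the parsing helper (synthetic strings, not real model output):

```python
assert extract_coordinates("(512, 384)") == (512, 384)  # parses the first (x, y) pair
assert extract_coordinates("no coordinates") == (0, 0)  # falls back to the origin
```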

Refer to our [code](https://github.com/Yan98/GTA1) for more details.