File size: 9,037 Bytes
38e1b8c a416ca8 38e1b8c a416ca8 bcd6528 758b134 a416ca8 758b134 a416ca8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
---
base_model:
- Qwen/Qwen2.5-VL-3B-Instruct
language:
- en
license: apache-2.0
tags:
- gui
- agent
pipeline_tag: image-text-to-text
library_name: transformers
---
# InfiGUI-R1-3B
This repository contains the model from the [InfiGUI-R1](https://arxiv.org/abs/2504.14239) paper. The model is based on `Qwen2.5-VL-3B-Instruct` and was trained with the Actor2Reasoner framework proposed in the paper, using reinforcement learning to improve its planning and reflection capabilities for GUI tasks.
## Quick Start
### Installation
First, install the required dependencies:
```bash
# PyTorch must already be installed (pick the build matching your CUDA version at pytorch.org).
# accelerate is needed for device_map="auto"; opencv-python and requests are used by the example.
pip install transformers accelerate qwen-vl-utils opencv-python requests
# FlashAttention-2 kernels, used by attn_implementation="flash_attention_2" (needs a CUDA build of PyTorch).
pip install flash-attn --no-build-isolation
```
### Example: GUI Grounding and Trajectory Tasks
```python
import importlib.util
import json

import cv2
import requests
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# Upper bound on the number of pixels per image. The processor below and the
# smart_resize call that computes the screen resolution shown in the prompts
# both use this same limit.
MAX_IMAGE_PIXELS = 5600 * 28 * 28
MODEL_PATH = "Reallm-Labs/InfiGUI-R1-3B"

# Load model and processor.
# FlashAttention-2 needs the optional `flash_attn` package; fall back to
# PyTorch scaled-dot-product attention ("sdpa") when it is not installed
# instead of failing at load time.
attn_implementation = (
    "flash_attention_2" if importlib.util.find_spec("flash_attn") is not None else "sdpa"
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
    device_map="auto",
)
# Left padding keeps every prompt in the batch right-aligned, so each sequence's
# generated tokens start right after its (padded) input length.
processor = AutoProcessor.from_pretrained(
    MODEL_PATH, max_pixels=MAX_IMAGE_PIXELS, padding_side="left"
)
# Prepare image: download the sample screenshot, then compute the resolution it
# is resized to. The model's predicted coordinates are expressed in this
# (new_width x new_height) space and are scaled back when visualizing.
img_url = "https://raw.githubusercontent.com/Reallm-Labs/InfiGUI-R1/main/images/test_img.png"
response = requests.get(img_url, timeout=30)
# Fail loudly on HTTP errors instead of saving an error page as "test_img.png"
# and crashing later with a confusing image-decoding error.
response.raise_for_status()
with open("test_img.png", "wb") as f:
    f.write(response.content)

# Only the size is needed here; close the file handle right away.
with Image.open("test_img.png") as image:
    width, height = image.size
new_height, new_width = smart_resize(height, width, max_pixels=MAX_IMAGE_PIXELS)
# Prepare inputs
instruction = "View detailed storage space usage"

# NOTE(review): in the extracted copy of this card the "\n" escapes inside these
# prompts had become literal line breaks (unterminated string literals). They
# are restored below; blank lines follow Qwen2.5's built-in tool-calling system
# template. Because these strings are the exact prompt text sent to the model,
# confirm whitespace (blank lines, any indentation of the "- ..." sub-bullets)
# against the upstream InfiGUI-R1 repository.
system_prompt = (
    "You FIRST think about the reasoning process as an internal monologue "
    "and then provide the final answer.\n"
    "The reasoning process MUST BE enclosed within <think> </think> tags."
)

# Function schema for the `mobile_use` tool. json.dumps (default separators
# ", " and ": ") emits the same {"key": "value", ...} layout as a hand-escaped
# string would, without the error-prone nested backslash escaping.
mobile_use_tool = {
    "type": "function",
    "function": {
        "name": "mobile_use",
        "description": (
            "Use a touchscreen to interact with a mobile device, and take screenshots.\n"
            "* This is an interface to a mobile device with touchscreen. You can perform actions like clicking, typing, swiping, etc.\n"
            "* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions.\n"
            f"* The screen's resolution is {new_width}x{new_height}.\n"
            "* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked."
        ),
        "parameters": {
            "properties": {
                "action": {
                    "description": (
                        "The action to perform. The available actions are:\n"
                        "* `key`: Perform a key event on the mobile device.\n"
                        "- This supports adb's `keyevent` syntax.\n"
                        '- Examples: "volume_up", "volume_down", "power", "camera", "clear".\n'
                        "* `click`: Click the point on the screen with coordinate (x, y).\n"
                        "* `long_press`: Press the point on the screen with coordinate (x, y) for specified seconds.\n"
                        "* `swipe`: Swipe from the starting point with coordinate (x, y) to the end point with coordinates2 (x2, y2).\n"
                        "* `type`: Input the specified text into the activated input box.\n"
                        "* `system_button`: Press the system button.\n"
                        "* `open`: Open an app on the device.\n"
                        "* `wait`: Wait specified seconds for the change to happen.\n"
                        "* `terminate`: Terminate the current task and report its completion status."
                    ),
                    "enum": ["key", "click", "long_press", "swipe", "type",
                             "system_button", "open", "wait", "terminate"],
                    "type": "string",
                },
                "coordinate": {
                    "description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=click`, `action=long_press`, and `action=swipe`.",
                    "type": "array",
                },
                "coordinate2": {
                    "description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=swipe`.",
                    "type": "array",
                },
                "text": {
                    "description": "Required only by `action=key`, `action=type`, and `action=open`.",
                    "type": "string",
                },
                "time": {
                    "description": "The seconds to wait. Required only by `action=long_press` and `action=wait`.",
                    "type": "number",
                },
                "button": {
                    "description": "Back means returning to the previous interface, Home means returning to the desktop, Menu means opening the application background menu, and Enter means pressing the enter. Required only by `action=system_button`",
                    "enum": ["Back", "Home", "Menu", "Enter"],
                    "type": "string",
                },
                "status": {
                    "description": "The status of the task. Required only by `action=terminate`.",
                    "type": "string",
                    "enum": ["success", "failure"],
                },
            },
            "required": ["action"],
            "type": "object",
        },
    },
}

# Appended to system_prompt for the trajectory task.
tool_prompt = (
    "\n\n# Tools\n\n"
    "You may call one or more functions to assist with the user query.\n\n"
    "You are provided with function signatures within <tools></tools> XML tags:\n"
    "<tools>\n"
    + json.dumps(mobile_use_tool)
    + "\n</tools>\n\n"
    "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
    "<tool_call>\n"
    '{"name": <function-name>, "arguments": <args-json-object>}\n'
    "</tool_call>"
)

grounding_prompt = f'''The screen's resolution is {new_width}x{new_height}.
Point to the UI element most relevant to "{instruction}", output its coordinates using JSON format:
```json
[
{{"point_2d": [x, y], "label": "object name/description"}}
]```'''

trajectory_prompt = (
    f"The user query: {instruction}\n"
    "Task progress (You have done the following operation on the current device): "
)
# Build messages: one chat conversation per task. Both conversations reference
# the same screenshot file on disk.
screenshot_path = "test_img.png"

# Grounding task: plain system prompt; the user turn shows the image first,
# then the pointing instruction.
grounding_user_content = [
    {"type": "image", "image": screenshot_path},
    {"type": "text", "text": grounding_prompt},
]
grounding_messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": grounding_user_content},
]

# Trajectory task: the system prompt also carries the tool definitions, and the
# task text comes before the screenshot.
trajectory_user_content = [
    {"type": "text", "text": trajectory_prompt},
    {"type": "image", "image": screenshot_path},
]
trajectory_messages = [
    {"role": "system", "content": system_prompt + tool_prompt},
    {"role": "user", "content": trajectory_user_content},
]

# Batch of two conversations: output 0 is grounding, output 1 is trajectory.
messages = [grounding_messages, trajectory_messages]
# Process and generate: render both conversations with the chat template,
# collect their images, and run a single batched generation.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=text,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)  # follow the model's placement (device_map="auto") instead of hard-coding "cuda"
generated_ids = model.generate(**inputs, max_new_tokens=512)

# Drop the prompt tokens from each sequence so only newly generated text is decoded.
generated_only = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_only,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
# Visualize results


def _draw_prediction(image, coords, label, out_path, model_size):
    """Draw one predicted point and its label on a copy of `image` and save it.

    Args:
        image: BGR image as returned by cv2.imread (left unmodified).
        coords: (x, y) predicted by the model, in the resized coordinate space.
        label: Text drawn next to the point.
        out_path: File the annotated copy is written to.
        model_size: (width, height) of the resized space `coords` refers to.
    """
    canvas = image.copy()
    img_h, img_w = image.shape[:2]
    model_w, model_h = model_size
    # Rescale from the resolution reported to the model back to the original screenshot.
    x = int(coords[0] / model_w * img_w)
    y = int(coords[1] / model_h * img_h)
    cv2.circle(canvas, (x, y), 10, (0, 0, 255), -1)
    cv2.putText(canvas, str(label), (x + 10, y - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
    cv2.imwrite(out_path, canvas)


# Keep only the final answer that follows the model's </think> reasoning block.
output_text = [ot.split("</think>")[-1] for ot in output_text]
grounding_output = output_text[0].replace("```json", "").replace("```", "").strip()
trajectory_output = output_text[1].replace("<tool_call>", "").replace("</tool_call>", "").strip()

try:
    grounding_output = json.loads(grounding_output)
    trajectory_output = json.loads(trajectory_output)
    grounding_coords = grounding_output[0]["point_2d"]
    grounding_label = grounding_output[0]["label"]
    trajectory_args = trajectory_output["arguments"]
    trajectory_coords = trajectory_args["coordinate"] if "coordinate" in trajectory_args else None
    trajectory_label = json.dumps(trajectory_args)

    # Load the original (un-resized) screenshot to draw on.
    img = cv2.imread("test_img.png")
    if img is None:
        raise ValueError("Could not load the image")

    if grounding_coords:
        _draw_prediction(img, grounding_coords, grounding_label,
                         "grounding_output.png", (new_width, new_height))
        print("Predicted coordinates:", grounding_coords)
        print("Grounding visualization saved to grounding_output.png")

    if trajectory_coords:
        _draw_prediction(img, trajectory_coords, trajectory_label,
                         "trajectory_output.png", (new_width, new_height))
        print("Predicted action:", trajectory_label)
        print("Trajectory visualization saved to trajectory_output.png")
except (ValueError, KeyError, IndexError, TypeError, cv2.error) as e:
    # ValueError also covers json.JSONDecodeError; the lookup/type errors cover
    # model output that is valid JSON but not in the expected shape. Anything
    # else (e.g. KeyboardInterrupt) propagates instead of being silently swallowed.
    print(f"Error: Failed to parse coordinates or process image ({type(e).__name__}: {e})")
```
For more information, please refer to our [repo](https://github.com/Reallm-Labs/InfiGUI-R1).