dpv007 committed on
Commit
2fbdc5f
·
verified ·
1 Parent(s): 3c7ea1e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ from transformers import AutoProcessor, AutoModelForVision2Seq
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"

# The processor and the model are created once, at import time, and are
# shared by every request handled by this app.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Weights are requested in float16 and device placement is left to the
# library via device_map="auto".
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto"
)
16
+
17
+ # ----------------------------
18
+ # Coordinate Extraction
19
+ # ----------------------------
20
# A non-negative decimal number: "12", "12.5", "12." or ".5".  Being strict
# here (instead of the looser "[\d\.]+") keeps tokens such as "1.2.3" or ".."
# from ever reaching float(), which would raise ValueError.
_NUMBER = r"(?:\d+(?:\.\d*)?|\.\d+)"

# "(x, y)" with optional whitespace around each number.
_POINT_RE = re.compile(r"\(\s*(" + _NUMBER + r")\s*,\s*(" + _NUMBER + r")\s*\)")

# "[x1, y1, x2, y2]": exactly four comma-separated numbers.
_BOX_RE = re.compile(
    r"\[\s*" + r"\s*,\s*".join(["(" + _NUMBER + ")"] * 4) + r"\s*\]"
)


def extract_coordinates(text, image_size):
    """Extract a point or a bounding box from model output text.

    Recognised forms, checked in this order:
      - a point written as ``(x, y)``
      - a bounding box written as ``[x1, y1, x2, y2]``

    If every matched value is <= 1 the values are treated as normalized
    (0.0-1.0) and scaled by the image size; otherwise they are treated as
    pixel values and truncated to integers.  Note that this heuristic cannot
    tell the pixel coordinate ``(1, 1)`` from the normalized corner
    ``(1.0, 1.0)``; both are scaled.

    Args:
        text: The decoded model output to search.
        image_size: ``(width, height)`` of the image in pixels.

    Returns:
        ``(x, y)`` for a point, ``(x1, y1, x2, y2)`` for a box, or ``None``
        when the text is empty or contains no well-formed coordinates.
        Malformed numeric tokens never raise; they are simply not matched.
    """
    if not text:
        return None

    width, height = image_size

    point = _POINT_RE.search(text)
    if point:
        x, y = (float(token) for token in point.groups())
        if x <= 1 and y <= 1:
            # Normalized coordinates: convert to pixels.
            return (int(x * width), int(y * height))
        return (int(x), int(y))

    # search() scans the whole text, so a bracket group that is not a
    # four-number box (e.g. "[1, 2]") no longer hides a later valid box.
    box = _BOX_RE.search(text)
    if box:
        x1, y1, x2, y2 = (float(token) for token in box.groups())
        if max(x1, y1, x2, y2) <= 1:
            # Normalized box: x values scale with width, y with height.
            return (
                int(x1 * width),
                int(y1 * height),
                int(x2 * width),
                int(y2 * height),
            )
        return (int(x1), int(y1), int(x2), int(y2))

    return None
62
+
63
+
64
+ # ----------------------------
65
+ # Prediction Function
66
+ # ----------------------------
67
def predict(image, prompt):
    """Run the model on a screenshot and instruction, then report coordinates.

    Args:
        image: The uploaded image as a numpy array (as delivered by the
            Gradio image input), or ``None`` when nothing was uploaded.
        prompt: The user's instruction text.

    Returns:
        A ``(model_output, coordinate_text)`` pair of strings for the two
        output text boxes.
    """
    if image is None:
        return "Please upload an image.", "No coordinates"

    image_pil = Image.fromarray(image).convert("RGB")
    width, height = image_pil.size

    # Build the prompt through the model's chat template so that the image
    # placeholder tokens the processor expands are present; passing the raw
    # instruction alone leaves the image without matching tokens in the text.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        images=[image_pil],
        text=[chat_text],
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=200)

    # generate() returns prompt + completion for this decoder-only model.
    # Decode only the new tokens so the echoed prompt is neither shown to the
    # user nor mistaken for coordinates by the regex search below.
    prompt_length = inputs["input_ids"].shape[1]
    result = processor.batch_decode(
        output[:, prompt_length:], skip_special_tokens=True
    )[0]

    # NOTE(review): coordinates are interpreted against the uploaded image's
    # size; if the model reports positions in the processor's resized image
    # space they would need rescaling -- confirm against the model card.
    coords = extract_coordinates(result, (width, height))

    if coords:
        coord_text = f"{coords} (Origin = top-left, x→right, y↓)"
    else:
        coord_text = "No coordinates detected"

    return result, coord_text
96
+
97
+
98
+ # ----------------------------
99
+ # Gradio UI
100
+ # ----------------------------
101
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 UI-TARS-1.5 GUI Agent Demo")

    # Inputs sit side by side: the screenshot and the instruction.
    with gr.Row():
        image_input = gr.Image(
            label="Upload Image / Screenshot",
            type="numpy",
        )
        text_input = gr.Textbox(
            label="Instruction / Prompt",
            placeholder="e.g. Click the login button",
        )

    run_btn = gr.Button("Run")

    # Outputs: the model's text and the coordinates parsed from it.
    output_text = gr.Textbox(label="Model Output")
    coord_output = gr.Textbox(label="Detected Coordinates")

    # Clicking "Run" sends both inputs to predict() and fills both outputs.
    run_btn.click(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=[output_text, coord_output],
    )

demo.launch()