prithivMLmods committed (verified)
Commit 73b010f · Parent: 6aaf210

Update app.py

Files changed (1): app.py (+149, -127)
app.py CHANGED
@@ -1,11 +1,13 @@
 import os
 import re
 import json
-import gc
 import time
+import shutil
+import uuid
+import tempfile
 import unicodedata
 from io import BytesIO
-from typing import Tuple, Optional, List, Dict, Any
+from typing import Tuple, Optional, List, Iterable

 import gradio as gr
 import numpy as np
@@ -14,17 +16,87 @@ import spaces
 from PIL import Image, ImageDraw, ImageFont

 # Transformers & Qwen Utils
-from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from transformers import (
+    Qwen2_5_VLForConditionalGeneration,
+    AutoProcessor,
+)
 from qwen_vl_utils import process_vision_info

+# Gradio Theme Utils
+from gradio.themes import Soft
+from gradio.themes.utils import colors, fonts, sizes
+
+colors.steel_blue = colors.Color(
+    name="steel_blue",
+    c50="#EBF3F8",
+    c100="#D3E5F0",
+    c200="#A8CCE1",
+    c300="#7DB3D2",
+    c400="#529AC3",
+    c500="#4682B4",
+    c600="#3E72A0",
+    c700="#36638C",
+    c800="#2E5378",
+    c900="#264364",
+    c950="#1E3450",
+)
+
+class SteelBlueTheme(Soft):
+    def __init__(
+        self,
+        *,
+        primary_hue: colors.Color | str = colors.gray,
+        secondary_hue: colors.Color | str = colors.steel_blue,
+        neutral_hue: colors.Color | str = colors.slate,
+        text_size: sizes.Size | str = sizes.text_lg,
+        font: fonts.Font | str | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
+        ),
+        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
+        ),
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            secondary_hue=secondary_hue,
+            neutral_hue=neutral_hue,
+            text_size=text_size,
+            font=font,
+            font_mono=font_mono,
+        )
+        super().set(
+            background_fill_primary="*primary_50",
+            background_fill_primary_dark="*primary_900",
+            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
+            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
+            button_primary_text_color="white",
+            button_primary_text_color_hover="white",
+            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
+            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
+            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
+            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
+            block_title_text_weight="600",
+            block_border_width="3px",
+            block_shadow="*shadow_drop_lg",
+            button_primary_shadow="*shadow_drop_lg",
+            button_large_padding="11px",
+        )
+
+steel_blue_theme = SteelBlueTheme()
+
+css = """
+#main-title h1 { font-size: 2.3em !important; }
+#out_img { height: 600px; object-fit: contain; }
+"""
+
 # -----------------------------------------------------------------------------
-# 1. CONSTANTS & PROMPTS
+# 2. MODEL LOADING (Global Setup)
 # -----------------------------------------------------------------------------

-# Device Configuration
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")

-# System Prompt (Forces models to output parseable JSON)
+# System Prompt
 OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
 You need to generate the next action to complete the task.

@@ -41,81 +113,31 @@ Examples:
 </tool_call>
 """

-# -----------------------------------------------------------------------------
-# 2. GLOBAL STATE & MODEL MANAGEMENT
-# -----------------------------------------------------------------------------
-
-# We use a global dictionary to hold the currently loaded model to avoid reloading if not changed
-current_model_state = {
-    "model": None,
-    "processor": None,
-    "name": None
-}
-
-def load_fara_model():
-    """Loads Microsoft Fara-7B to CUDA"""
-    print("🔄 Loading Fara-7B...")
-    MODEL_ID_V = "microsoft/Fara-7B"
-
-    processor = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
-    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        MODEL_ID_V,
-        trust_remote_code=True,
-        torch_dtype=torch.float16  # As requested
-    ).to(DEVICE).eval()
-
-    return model, processor
-
-def load_uitars_model():
-    """Loads UI-TARS-1.5-7B to CUDA"""
-    print("🔄 Loading UI-TARS...")
-    # Note: Using the official SFT repo as the specific ID provided in snippet might be private/incorrect
-    MODEL_ID_X = "bytedance/UI-TARS-7B-SFT"
-
-    processor = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
-    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        MODEL_ID_X,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16,  # As requested
-    ).to(DEVICE).eval()
-
-    return model, processor
-
-def get_model_pipeline(model_choice: str):
-    """
-    Manages VRAM: Unloads old model, loads new model based on user choice.
-    """
-    global current_model_state
-
-    # If the requested model is already loaded, return it
-    if current_model_state["name"] == model_choice and current_model_state["model"] is not None:
-        return current_model_state["model"], current_model_state["processor"]
-
-    # Otherwise, clear VRAM first
-    if current_model_state["model"] is not None:
-        print("🗑️ Unloading previous model to free VRAM...")
-        del current_model_state["model"]
-        del current_model_state["processor"]
-        current_model_state["model"] = None
-        current_model_state["processor"] = None
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    # Load the requested model
-    if model_choice == "Fara-7B":
-        model, processor = load_fara_model()
-    else:
-        model, processor = load_uitars_model()
-
-    # Update state
-    current_model_state["model"] = model
-    current_model_state["processor"] = processor
-    current_model_state["name"] = model_choice
-
-    return model, processor
+# Load Fara-7B
+print("Loading Fara-7B...")
+MODEL_ID_V = "microsoft/Fara-7B"
+processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
+model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_V,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16
+).to(device).eval()
+
+# Load UI-TARS-1.5-7B
+print("Loading UI-TARS-1.5-7B...")
+# Note: Using the official SFT repo. Adjust if you have a specific private repo.
+MODEL_ID_X = "bytedance/UI-TARS-7B"
+processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_X,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+).to(device).eval()
+
+print("✅ All Models Loaded Successfully")

 # -----------------------------------------------------------------------------
-# 3. UTILS: IMAGE & PARSING
+# 3. UTILS: IMAGE, PARSING, VISUALIZATION
 # -----------------------------------------------------------------------------

 def array_to_image(image_array: np.ndarray) -> Image.Image:
@@ -149,7 +171,6 @@ def parse_tool_calls(response: str) -> list[dict]:
             action_type = args.get("action", "unknown")
             text_content = args.get("text", "")

-            # Basic validation
             if coords and isinstance(coords, list) and len(coords) == 2:
                 actions.append({
                     "type": action_type,
@@ -160,8 +181,12 @@ def parse_tool_calls(response: str) -> list[dict]:
                 })
                 print(f"Parsed Action: {action_type} at {coords}")
             else:
-                # Some actions (like key press) might not have coords
-                print(f"Action parsed without coords: {action_type}")
+                # Handle actions without coordinates (like pressing enter generally)
+                actions.append({
+                    "type": action_type,
+                    "text": text_content,
+                    "raw_json": data
+                })

         except json.JSONDecodeError:
             print(f"Failed to parse JSON: {match}")
@@ -169,7 +194,7 @@ def parse_tool_calls(response: str) -> list[dict]:
     return actions

 def create_localized_image(original_image: Image.Image, actions: list[dict]) -> Optional[Image.Image]:
-    """Draws visual markers on the screenshot."""
+    """Draws markers on the image based on parsed pixel coordinates."""
     if not actions:
         return None

@@ -183,18 +208,23 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
         font = None

     colors = {
-        'click': 'red', 'left_click': 'red',
         'type': 'blue',
+        'click': 'red',
+        'left_click': 'red',
         'right_click': 'purple',
         'double_click': 'orange',
         'unknown': 'green'
     }

     for act in actions:
+        # Only draw if coordinates exist
+        if 'x' not in act or 'y' not in act:
+            continue
+
         x = act['x']
         y = act['y']

-        # Handle normalized (0-1) vs pixel coordinates
+        # Check if Normalized (0.0 - 1.0) or Absolute (Pixels > 1.0)
         if x <= 1.0 and y <= 1.0 and x > 0:
             pixel_x = int(x * width)
             pixel_y = int(y * height)
@@ -205,46 +235,51 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
         action_type = act['type']
         color = colors.get(action_type, 'green')

-        # Draw Target
-        r = 15
+        # Draw Circle Target
+        r = 12
         draw.ellipse([pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r], outline=color, width=4)
-        draw.ellipse([pixel_x - 4, pixel_y - 4, pixel_x + 4, pixel_y + 4], fill=color)
+        draw.ellipse([pixel_x - 3, pixel_y - 3, pixel_x + 3, pixel_y + 3], fill=color)

-        # Draw Label
+        # Draw Label text
         label_text = f"{action_type}"
         if act['text']:
             label_text += f": '{act['text']}'"

-        text_pos = (pixel_x + 18, pixel_y - 12)
+        text_pos = (pixel_x + 15, pixel_y - 10)
         bbox = draw.textbbox(text_pos, label_text, font=font)
-        draw.rectangle((bbox[0]-2, bbox[1]-2, bbox[2]+2, bbox[3]+2), fill="black")
+        draw.rectangle(bbox, fill="black")
         draw.text(text_pos, label_text, fill="white", font=font)

     return img_copy

 # -----------------------------------------------------------------------------
-# 4. GRADIO LOGIC (ZERO-GPU ENABLED)
+# 4. PROCESSING LOGIC
 # -----------------------------------------------------------------------------

-@spaces.GPU(duration=120)
+@spaces.GPU
 def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: str) -> Tuple[str, Optional[Image.Image]]:
     if input_numpy_image is None:
         return "⚠️ Please upload an image first.", None

-    # 1. Load the specific model requested (Fara or UI-TARS) to CUDA
-    model, processor = get_model_pipeline(model_choice)
+    # 1. Select Model
+    if model_choice == "Fara-7B":
+        model = model_v
+        processor = processor_v
+    elif model_choice == "UI-TARS-1.5-7B":
+        model = model_x
+        processor = processor_x
+    else:
+        return "Invalid model selection", None

     # 2. Prepare Data
     input_pil_image = array_to_image(input_numpy_image)
-    prompt_messages = get_navigation_prompt(task, input_pil_image)
+    prompt = get_navigation_prompt(task, input_pil_image)

     # 3. Generate
-    print(f"Generating response using {model_choice}...")
-
     text_prompts = processor.apply_chat_template(
-        prompt_messages, tokenize=False, add_generation_prompt=True
+        prompt, tokenize=False, add_generation_prompt=True
     )
-    image_inputs, video_inputs = process_vision_info(prompt_messages)
+    image_inputs, video_inputs = process_vision_info(prompt)

     inputs = processor(
         text=[text_prompts],
@@ -253,8 +288,9 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: str
         padding=True,
         return_tensors="pt",
     )
-    inputs = inputs.to(model.device)
+    inputs = inputs.to(device)

+    print(f"Generating with {model_choice}...")
     with torch.no_grad():
         generated_ids = model.generate(**inputs, max_new_tokens=512)

@@ -280,35 +316,21 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: str
     return raw_response, output_image

 # -----------------------------------------------------------------------------
-# 5. GRADIO UI SETUP
+# 5. GRADIO UI
 # -----------------------------------------------------------------------------

-title = "CUA GUI Agent 🖥️"
-description = """
-**Computer Use Agent (CUA)** Demo.
-Upload a screenshot and provide a task instruction. The model will analyze the UI and output the precise coordinates and actions required.
-
-**Models Supported:**
-* **Fara-7B**: Microsoft's GUI agent model.
-* **UI-TARS-1.5-7B**: ByteDance's GUI agent model.
-"""
-
-custom_css = """
-#out_img { height: 600px; object-fit: contain; }
-"""
-
-with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
-    gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
-    gr.Markdown(description)
+with gr.Blocks(theme=steel_blue_theme, css=css) as demo:
+    gr.Markdown("# **CUA GUI Agent 🖥️**", elem_id="main-title")
+    gr.Markdown("Upload a screenshot, select a model, and provide a task. The model will determine the precise UI coordinates and actions.")

     with gr.Row():
-        with gr.Column():
+        with gr.Column(scale=2):
             input_image = gr.Image(label="Upload Screenshot", height=500)

             with gr.Row():
-                model_choice = gr.Dropdown(
-                    label="Choose CUA Model",
+                model_choice = gr.Radio(
                     choices=["Fara-7B", "UI-TARS-1.5-7B"],
+                    label="Select Model",
                     value="Fara-7B",
                     interactive=True
                 )
@@ -320,9 +342,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
             )
             submit_btn = gr.Button("Analyze UI & Generate Action", variant="primary")

-        with gr.Column():
+        with gr.Column(scale=3):
             output_image = gr.Image(label="Visualized Action Points", elem_id="out_img", height=500)
-            output_text = gr.Textbox(label="Raw Model Output", lines=8, show_copy_button=True)
+            output_text = gr.Textbox(label="Raw Model Output (JSON)", lines=8, show_copy_button=True)

     # Wire up the button
     submit_btn.click(
@@ -331,7 +353,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
         outputs=[output_text, output_image]
     )

-    # Example for quick testing
+    # Examples
     gr.Examples(
         examples=[
             ["./assets/google.png", "Search for 'Hugging Face'", "Fara-7B"],
@@ -341,4 +363,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
     )

 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue(max_size=20).launch(show_error=True)
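A minimal standalone sketch of the parser change above, for readers skimming the diff: parse_tool_calls now keeps actions that arrive without a coordinate (previously they were only logged and dropped). The <tool_call> regex and the arguments/coordinate key names are assumptions here, since those lines sit outside the hunks shown; only the two branches of the "if coords ..." block follow the diff verbatim.

import json
import re

# Assumed extraction pattern; the actual regex is outside the shown hunks.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def parse_tool_calls(response: str) -> list[dict]:
    """Extract action dicts from <tool_call>...</tool_call> blocks."""
    actions = []
    for match in TOOL_CALL_RE.findall(response):
        try:
            data = json.loads(match)
            args = data.get("arguments", {})  # key name assumed
            coords = args.get("coordinate")   # key name assumed
            action_type = args.get("action", "unknown")
            text_content = args.get("text", "")

            if coords and isinstance(coords, list) and len(coords) == 2:
                actions.append({
                    "type": action_type,
                    "x": coords[0],
                    "y": coords[1],
                    "text": text_content,
                    "raw_json": data,
                })
            else:
                # New behaviour: keep coordinate-less actions (e.g. a bare key press)
                actions.append({
                    "type": action_type,
                    "text": text_content,
                    "raw_json": data,
                })
        except json.JSONDecodeError:
            print(f"Failed to parse JSON: {match}")
    return actions

sample = '<tool_call>{"name": "computer_use", "arguments": {"action": "key", "text": "Enter"}}</tool_call>'
print(parse_tool_calls(sample))  # -> [{'type': 'key', 'text': 'Enter', 'raw_json': {...}}]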
 
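The lines between model.generate and the final "return raw_response, output_image" are unchanged by this commit and therefore elided from the diff. For context, a typical Qwen2.5-VL decoding step looks like the sketch below: trim the prompt tokens from each sequence, then batch-decode only the newly generated part. This is the standard pattern from the Qwen2.5-VL examples, not necessarily this file's exact code.

# Sketch of the elided decoding step; `inputs`, `generated_ids`, and
# `processor` are the variables already in scope in process_screenshot.
generated_ids_trimmed = [
    out_ids[len(in_ids):]
    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
raw_response = processor.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]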