prithivMLmods committed on
Commit
6f12eee
·
verified ·
1 Parent(s): fcb0f85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -79
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
2
  import re
3
  import json
 
4
  import time
5
  import shutil
6
  import uuid
7
- import gc
8
  import tempfile
9
  import unicodedata
10
  from io import BytesIO
@@ -21,18 +21,24 @@ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
21
  from qwen_vl_utils import process_vision_info
22
 
23
  # -----------------------------------------------------------------------------
24
- # 1. CONSTANTS & CONFIGURATION
25
  # -----------------------------------------------------------------------------
26
 
27
- # Map display names to Hugging Face Repo IDs
28
  MODEL_MAP = {
29
  "Fara-7B": "microsoft/Fara-7B",
30
- "UI-TARS-1.5-7B": "ByteDance-Seed/UI-TARS-1.5-7B"
 
31
  }
32
 
33
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
 
35
- # System Prompt
 
 
 
 
 
36
  OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
37
  You need to generate the next action to complete the task.
38
 
@@ -50,49 +56,38 @@ Examples:
50
  """
51
 
52
  # -----------------------------------------------------------------------------
53
- # 2. GLOBAL MODEL STATE MANAGEMENT
54
  # -----------------------------------------------------------------------------
55
 
56
- # Global variables to track the currently loaded model
57
- CURRENT_MODEL = None
58
- CURRENT_PROCESSOR = None
59
- CURRENT_MODEL_ID = None
60
-
61
- def load_model(model_key: str):
62
  """
63
- Dynamically loads the requested model.
64
- Unloads the previous model to free up GPU memory if a switch occurs.
65
  """
66
- global CURRENT_MODEL, CURRENT_PROCESSOR, CURRENT_MODEL_ID
67
 
68
- target_repo_id = MODEL_MAP[model_key]
69
 
70
- # If the requested model is already loaded, do nothing
71
- if CURRENT_MODEL is not None and CURRENT_MODEL_ID == target_repo_id:
72
- print(f"Model {model_key} is already loaded.")
73
- return
74
 
75
- print(f"--- Switching Model to {model_key} ({target_repo_id}) ---")
76
 
77
- # 1. Unload existing model to free GPU memory
78
  if CURRENT_MODEL is not None:
79
- print("Unloading current model...")
80
  del CURRENT_MODEL
81
  del CURRENT_PROCESSOR
82
  CURRENT_MODEL = None
83
  CURRENT_PROCESSOR = None
84
  gc.collect()
85
  torch.cuda.empty_cache()
86
- print("Memory cleared.")
87
 
88
- # 2. Load new model
89
  try:
90
- print(f"Loading processor for {target_repo_id}...")
91
- processor = AutoProcessor.from_pretrained(target_repo_id, trust_remote_code=True)
92
-
93
- print(f"Loading model weights for {target_repo_id}...")
94
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
95
- target_repo_id,
96
  trust_remote_code=True,
97
  torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
98
  device_map="auto" if DEVICE == "cuda" else None,
@@ -103,47 +98,51 @@ def load_model(model_key: str):
103
 
104
  model.eval()
105
 
106
- # Update global state
107
  CURRENT_MODEL = model
108
  CURRENT_PROCESSOR = processor
109
- CURRENT_MODEL_ID = target_repo_id
110
- print(f"Successfully loaded {model_key}.")
 
111
 
112
  except Exception as e:
113
- print(f"Error loading model {target_repo_id}: {e}")
114
  raise e
115
 
116
- def generate_response(messages: list[dict], max_new_tokens=512):
117
- """
118
- Runs generation using the currently loaded global model.
119
- """
120
- if CURRENT_MODEL is None or CURRENT_PROCESSOR is None:
121
- raise ValueError("No model loaded.")
122
-
123
- text = CURRENT_PROCESSOR.apply_chat_template(
124
  messages, tokenize=False, add_generation_prompt=True
125
  )
 
 
126
  image_inputs, video_inputs = process_vision_info(messages)
127
 
128
- inputs = CURRENT_PROCESSOR(
 
129
  text=[text],
130
  images=image_inputs,
131
  videos=video_inputs,
132
  padding=True,
133
  return_tensors="pt",
134
  )
135
- inputs = inputs.to(CURRENT_MODEL.device)
136
 
 
137
  with torch.no_grad():
138
- generated_ids = CURRENT_MODEL.generate(**inputs, max_new_tokens=max_new_tokens)
139
 
 
140
  generated_ids_trimmed = [
141
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
142
  ]
143
 
144
- return CURRENT_PROCESSOR.batch_decode(
145
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
146
  )[0]
 
 
147
 
148
  # -----------------------------------------------------------------------------
149
  # 3. PARSING & VISUALIZATION LOGIC
@@ -164,18 +163,25 @@ def get_navigation_prompt(task, image):
164
  ]
165
 
166
  def parse_tool_calls(response: str) -> list[dict]:
 
 
 
167
  actions = []
 
 
168
  matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
169
 
170
  for match in matches:
171
  try:
172
  json_str = match.strip()
173
  data = json.loads(json_str)
 
174
  args = data.get("arguments", {})
175
  coords = args.get("coordinate", [])
176
  action_type = args.get("action", "unknown")
177
  text_content = args.get("text", "")
178
 
 
179
  if coords and isinstance(coords, list) and len(coords) == 2:
180
  actions.append({
181
  "type": action_type,
@@ -184,12 +190,18 @@ def parse_tool_calls(response: str) -> list[dict]:
184
  "text": text_content,
185
  "raw_json": data
186
  })
187
- except Exception as e:
188
- print(f"Error parsing tool call: {e}")
 
 
 
 
 
189
 
190
  return actions
191
 
192
  def create_localized_image(original_image: Image.Image, actions: list[dict]) -> Optional[Image.Image]:
 
193
  if not actions:
194
  return None
195
 
@@ -211,11 +223,11 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
211
  'unknown': 'green'
212
  }
213
 
214
- for act in actions:
215
  x = act['x']
216
  y = act['y']
217
 
218
- # Determine if coords are normalized or absolute
219
  if x <= 1.0 and y <= 1.0 and x > 0:
220
  pixel_x = int(x * width)
221
  pixel_y = int(y * height)
@@ -226,42 +238,54 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
226
  action_type = act['type']
227
  color = colors.get(action_type, 'green')
228
 
229
- # Draw Target
230
- r = 12
231
- draw.ellipse([pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r], outline=color, width=4)
232
- draw.ellipse([pixel_x - 3, pixel_y - 3, pixel_x + 3, pixel_y + 3], fill=color)
 
 
 
 
 
 
 
 
 
233
 
234
- # Label
235
  label_text = f"{action_type}"
236
  if act['text']:
237
  label_text += f": '{act['text']}'"
238
 
239
- text_pos = (pixel_x + 15, pixel_y - 10)
 
240
  bbox = draw.textbbox(text_pos, label_text, font=font)
 
 
241
  draw.rectangle(bbox, fill="black")
242
  draw.text(text_pos, label_text, fill="white", font=font)
243
 
244
  return img_copy
245
 
246
  # -----------------------------------------------------------------------------
247
- # 4. GRADIO PROCESSING LOGIC
248
  # -----------------------------------------------------------------------------
249
 
250
  @spaces.GPU(duration=120)
251
- def process_screenshot(input_numpy_image: np.ndarray, task: str, selected_model_key: str) -> Tuple[str, Optional[Image.Image]]:
252
  if input_numpy_image is None:
253
  return "⚠️ Please upload an image first.", None
254
 
255
- # 1. Ensure correct model is loaded
256
- load_model(selected_model_key)
257
 
258
  # 2. Prepare Data
259
  input_pil_image = array_to_image(input_numpy_image)
260
  prompt = get_navigation_prompt(task, input_pil_image)
261
 
262
  # 3. Generate
263
- print(f"Generating with {selected_model_key}...")
264
- raw_response = generate_response(prompt, max_new_tokens=500)
265
  print(f"Raw Output:\n{raw_response}")
266
 
267
  # 4. Parse & Visualize
@@ -276,13 +300,17 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, selected_model_
276
  return raw_response, output_image
277
 
278
  # -----------------------------------------------------------------------------
279
- # 5. UI SETUP
280
  # -----------------------------------------------------------------------------
281
 
282
- title = "Computer Use Agent (CUA) Playground 🖥️"
283
  description = """
284
- Analyze GUI screenshots and generate action coordinates using State-of-the-Art Vision Language Models.
285
- Supported Models: **Microsoft Fara-7B** and **ByteDance UI-TARS-1.5-7B**.
 
 
 
 
286
  """
287
 
288
  custom_css = """
@@ -297,13 +325,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
297
  with gr.Column():
298
  input_image = gr.Image(label="Upload Screenshot", height=500)
299
 
300
- # Model Selector
301
- model_selector = gr.Dropdown(
302
- label="Choose CUA Model",
303
- choices=["Fara-7B", "UI-TARS-1.5-7B"],
304
- value="Fara-7B",
305
- interactive=True
306
- )
307
 
308
  task_input = gr.Textbox(
309
  label="Task Instruction",
@@ -319,20 +347,20 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
319
  # Wire up the button
320
  submit_btn.click(
321
  fn=process_screenshot,
322
- inputs=[input_image, task_input, model_selector],
323
  outputs=[output_text, output_image]
324
  )
325
 
 
326
  gr.Examples(
327
  examples=[
328
- ["./assets/google.png", "Search for 'Hugging Face'", "Fara-7B"],
329
- ["./assets/google.png", "Click the Sign In button", "UI-TARS-1.5-7B"],
330
  ],
331
- inputs=[input_image, task_input, model_selector],
332
  label="Quick Examples"
333
  )
334
 
335
  if __name__ == "__main__":
336
- # Pre-load the default model on startup to speed up first inference (optional)
337
- # load_model("Fara-7B")
338
  demo.queue().launch()
 
1
  import os
2
  import re
3
  import json
4
+ import gc
5
  import time
6
  import shutil
7
  import uuid
 
8
  import tempfile
9
  import unicodedata
10
  from io import BytesIO
 
21
  from qwen_vl_utils import process_vision_info
22
 
23
  # -----------------------------------------------------------------------------
24
+ # 1. CONSTANTS & SYSTEM PROMPT
25
  # -----------------------------------------------------------------------------
26
 
27
+ # Mapping UI labels to Hugging Face Model IDs
28
  MODEL_MAP = {
29
  "Fara-7B": "microsoft/Fara-7B",
30
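+ # Using the official SFT checkpoint for UI-TARS. NOTE(review): the UI label says "1.5" but the repo ID below names UI-TARS-7B-SFT; confirm they refer to the same model.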
31
+ "UI-TARS-1.5-7B": "bytedance/UI-TARS-7B-SFT"
32
  }
33
 
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
 
36
+ # Global model state
37
+ CURRENT_MODEL = None
38
+ CURRENT_PROCESSOR = None
39
+ CURRENT_MODEL_NAME = None
40
+
41
+ # Updated System Prompt to encourage the JSON format
42
  OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
43
  You need to generate the next action to complete the task.
44
 
 
56
  """
57
 
58
  # -----------------------------------------------------------------------------
59
+ # 2. MODEL LOADING LOGIC
60
  # -----------------------------------------------------------------------------
61
 
62
+ def load_model_to_device(model_name: str):
 
 
 
 
 
63
  """
64
+ Loads the specified model to GPU, unloading previous models to save VRAM.
 
65
  """
66
+ global CURRENT_MODEL, CURRENT_PROCESSOR, CURRENT_MODEL_NAME
67
 
68
+ target_id = MODEL_MAP.get(model_name, model_name)
69
 
70
+ # If already loaded, skip
71
+ if CURRENT_MODEL_NAME == model_name and CURRENT_MODEL is not None:
72
+ return CURRENT_MODEL, CURRENT_PROCESSOR
 
73
 
74
+ print(f"🔄 Switching model to: {model_name} ({target_id})...")
75
 
76
+ # 1. Cleanup previous model
77
  if CURRENT_MODEL is not None:
 
78
  del CURRENT_MODEL
79
  del CURRENT_PROCESSOR
80
  CURRENT_MODEL = None
81
  CURRENT_PROCESSOR = None
82
  gc.collect()
83
  torch.cuda.empty_cache()
84
+ print("🗑️ Previous model unloaded.")
85
 
86
+ # 2. Load New Model
87
  try:
88
+ processor = AutoProcessor.from_pretrained(target_id, trust_remote_code=True)
 
 
 
89
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
90
+ target_id,
91
  trust_remote_code=True,
92
  torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
93
  device_map="auto" if DEVICE == "cuda" else None,
 
98
 
99
  model.eval()
100
 
 
101
  CURRENT_MODEL = model
102
  CURRENT_PROCESSOR = processor
103
+ CURRENT_MODEL_NAME = model_name
104
+ print(f" {model_name} loaded successfully.")
105
+ return model, processor
106
 
107
  except Exception as e:
108
+ print(f"Error loading {model_name}: {e}")
109
  raise e
110
 
111
+ def generate_response(model, processor, messages, max_new_tokens=512):
112
+ """Generic generation function for Qwen2.5-VL based models"""
113
+
114
+ # Apply Chat Template
115
+ text = processor.apply_chat_template(
 
 
 
116
  messages, tokenize=False, add_generation_prompt=True
117
  )
118
+
119
+ # Process Images
120
  image_inputs, video_inputs = process_vision_info(messages)
121
 
122
+ # Prepare Inputs
123
+ inputs = processor(
124
  text=[text],
125
  images=image_inputs,
126
  videos=video_inputs,
127
  padding=True,
128
  return_tensors="pt",
129
  )
130
+ inputs = inputs.to(model.device)
131
 
132
+ # Generate
133
  with torch.no_grad():
134
+ generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
135
 
136
+ # Decode
137
  generated_ids_trimmed = [
138
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
139
  ]
140
 
141
+ output_text = processor.batch_decode(
142
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
143
  )[0]
144
+
145
+ return output_text
146
 
147
  # -----------------------------------------------------------------------------
148
  # 3. PARSING & VISUALIZATION LOGIC
 
163
  ]
164
 
165
  def parse_tool_calls(response: str) -> list[dict]:
166
+ """
167
+ Parses the <tool_call>{JSON}</tool_call> format.
168
+ """
169
  actions = []
170
+
171
+ # Regex to find content between <tool_call> tags
172
  matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
173
 
174
  for match in matches:
175
  try:
176
  json_str = match.strip()
177
  data = json.loads(json_str)
178
+
179
  args = data.get("arguments", {})
180
  coords = args.get("coordinate", [])
181
  action_type = args.get("action", "unknown")
182
  text_content = args.get("text", "")
183
 
184
+ # Check if coords exist and are a list of length 2
185
  if coords and isinstance(coords, list) and len(coords) == 2:
186
  actions.append({
187
  "type": action_type,
 
190
  "text": text_content,
191
  "raw_json": data
192
  })
193
+ print(f"Parsed Action: {action_type} at {coords}")
194
+ else:
195
+ # Some actions like 'scroll' might not have coordinates in some models
196
+ print(f"Non-coordinate action or invalid: {json_str}")
197
+
198
+ except json.JSONDecodeError as e:
199
+ print(f"Failed to parse JSON: {e}")
200
 
201
  return actions
202
 
203
  def create_localized_image(original_image: Image.Image, actions: list[dict]) -> Optional[Image.Image]:
204
+ """Draws markers on the image based on parsed pixel coordinates."""
205
  if not actions:
206
  return None
207
 
 
223
  'unknown': 'green'
224
  }
225
 
226
+ for i, act in enumerate(actions):
227
  x = act['x']
228
  y = act['y']
229
 
230
+ # Treat coords as normalized (scaled by image size) when 0 < x <= 1.0 and y <= 1.0; otherwise as absolute pixels
231
  if x <= 1.0 and y <= 1.0 and x > 0:
232
  pixel_x = int(x * width)
233
  pixel_y = int(y * height)
 
238
  action_type = act['type']
239
  color = colors.get(action_type, 'green')
240
 
241
+ # Draw Circle Target
242
+ r = 15 # Radius
243
+ draw.ellipse(
244
+ [pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r],
245
+ outline=color,
246
+ width=4
247
+ )
248
+
249
+ # Draw Center Dot
250
+ draw.ellipse(
251
+ [pixel_x - 4, pixel_y - 4, pixel_x + 4, pixel_y + 4],
252
+ fill=color
253
+ )
254
 
255
+ # Label Text
256
  label_text = f"{action_type}"
257
  if act['text']:
258
  label_text += f": '{act['text']}'"
259
 
260
+ # Text Background
261
+ text_pos = (pixel_x + 18, pixel_y - 12)
262
  bbox = draw.textbbox(text_pos, label_text, font=font)
263
+ # Add padding to bbox
264
+ bbox = (bbox[0]-2, bbox[1]-2, bbox[2]+2, bbox[3]+2)
265
  draw.rectangle(bbox, fill="black")
266
  draw.text(text_pos, label_text, fill="white", font=font)
267
 
268
  return img_copy
269
 
270
  # -----------------------------------------------------------------------------
271
+ # 4. GRADIO LOGIC
272
  # -----------------------------------------------------------------------------
273
 
274
  @spaces.GPU(duration=120)
275
+ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: str) -> Tuple[str, Optional[Image.Image]]:
276
  if input_numpy_image is None:
277
  return "⚠️ Please upload an image first.", None
278
 
279
+ # 1. Load Requested Model (Switching if necessary)
280
+ model, processor = load_model_to_device(model_choice)
281
 
282
  # 2. Prepare Data
283
  input_pil_image = array_to_image(input_numpy_image)
284
  prompt = get_navigation_prompt(task, input_pil_image)
285
 
286
  # 3. Generate
287
+ print(f"Generating response using {model_choice}...")
288
+ raw_response = generate_response(model, processor, prompt, max_new_tokens=512)
289
  print(f"Raw Output:\n{raw_response}")
290
 
291
  # 4. Parse & Visualize
 
300
  return raw_response, output_image
301
 
302
  # -----------------------------------------------------------------------------
303
+ # 5. GRADIO UI SETUP
304
  # -----------------------------------------------------------------------------
305
 
306
+ title = "CUA GUI Agent 🖥️"
307
  description = """
308
+ **Computer Use Agent (CUA)** Demo.
309
+ Upload a screenshot and provide a task instruction. The model will analyze the UI and output the precise coordinates and actions required.
310
+
311
+ **Models Supported:**
312
+ * **Fara-7B**: Microsoft's GUI agent model.
313
+ * **UI-TARS-1.5-7B**: ByteDance's GUI agent model.
314
  """
315
 
316
  custom_css = """
 
325
  with gr.Column():
326
  input_image = gr.Image(label="Upload Screenshot", height=500)
327
 
328
+ with gr.Row():
329
+ model_choice = gr.Dropdown(
330
+ label="Choose CUA Model",
331
+ choices=list(MODEL_MAP.keys()),
332
+ value="Fara-7B",
333
+ interactive=True
334
+ )
335
 
336
  task_input = gr.Textbox(
337
  label="Task Instruction",
 
347
  # Wire up the button
348
  submit_btn.click(
349
  fn=process_screenshot,
350
+ inputs=[input_image, task_input, model_choice],
351
  outputs=[output_text, output_image]
352
  )
353
 
354
+ # Example for quick testing
355
  gr.Examples(
356
  examples=[
357
+ ["./assets/google.png", "Search for 'Hugging Face'", "Fara-7B"],
 
358
  ],
359
+ inputs=[input_image, task_input, model_choice],
360
  label="Quick Examples"
361
  )
362
 
363
  if __name__ == "__main__":
364
+ # Pre-loading the default model here would speed up the first request if memory allows,
365
+ # but loading strictly inside the GPU request is safer for Spaces, so nothing is pre-loaded.
366
  demo.queue().launch()