prithivMLmods committed
Commit 5c21f23 · verified · 1 Parent(s): 7d0d550

Update app.py

Files changed (1):
  1. app.py +111 -73
app.py CHANGED
@@ -6,6 +6,7 @@ import shutil
 import uuid
 import tempfile
 import unicodedata
+import gc
 from io import BytesIO
 from typing import Tuple, Optional, List, Dict, Any
 
@@ -23,10 +24,15 @@ from qwen_vl_utils import process_vision_info
 # 1. CONSTANTS & SYSTEM PROMPT
 # -----------------------------------------------------------------------------
 
-MODEL_ID = "microsoft/Fara-7B"
+# Available Models
+MODELS = {
+    "Fara-7B": "microsoft/Fara-7B",
+    "UI-TARS-1.5-7B": "ByteDance-Seed/UI-TARS-1.5-7B"
+}
+
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Updated System Prompt to encourage the JSON format the model prefers
+# System Prompt asking for JSON format
 OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
 You need to generate the next action to complete the task.
 
@@ -44,31 +50,59 @@ Examples:
 """
 
 # -----------------------------------------------------------------------------
-# 2. MODEL DEFINITION
+# 2. MODEL MANAGEMENT
 # -----------------------------------------------------------------------------
 
-class FaraTransformersModel:
-    def __init__(self, model_id: str, to_device: str = "cuda"):
-        print(f"Loading {model_id} on {to_device}...")
-        self.model_id = model_id
+class ModelManager:
+    def __init__(self):
+        self.current_model_id = None
+        self.model = None
+        self.processor = None
+
+    def load_model(self, model_key):
+        model_id = MODELS.get(model_key)
+        if not model_id:
+            raise ValueError(f"Unknown model: {model_key}")
+
+        # If already loaded, skip
+        if self.current_model_id == model_id and self.model is not None:
+            return
+
+        print(f"--- Swapping model to {model_key} ({model_id}) ---")
 
+        # Unload previous model to save VRAM
+        if self.model is not None:
+            del self.model
+            del self.processor
+            self.model = None
+            self.processor = None
+            gc.collect()
+            torch.cuda.empty_cache()
+            print("Previous model unloaded.")
+
+        print(f"Loading {model_id}...")
         try:
             self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
             self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                 model_id,
                 trust_remote_code=True,
-                torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
-                device_map="auto" if to_device == "cuda" else None,
+                torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
+                device_map="auto" if DEVICE == "cuda" else None,
             )
-            if to_device == "cpu":
+            if DEVICE == "cpu":
                 self.model.to("cpu")
             self.model.eval()
-            print("Model loaded successfully.")
+            self.current_model_id = model_id
+            print(f"Successfully loaded {model_key}")
         except Exception as e:
-            print(f"Error loading Fara: {e}")
+            print(f"Error loading model {model_id}: {e}")
             raise e
 
-    def generate(self, messages: list[dict], max_new_tokens=512):
+    def generate(self, model_key, messages, max_new_tokens=512):
+        # Ensure correct model is loaded
+        self.load_model(model_key)
+
+        # Prepare inputs
         text = self.processor.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
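The heart of this change is the unload-then-load pattern in load_model: drop every reference to the old weights, force a garbage-collection pass, and release cached CUDA blocks before fetching the next checkpoint. A minimal standalone sketch of the same pattern (the class name and the injected loader callback are illustrative, not part of the Space):

import gc
import torch

class SwappableModel:
    """Keeps at most one model resident; frees VRAM before swapping."""

    def __init__(self, registry: dict):
        self.registry = registry      # e.g. {"Fara-7B": "microsoft/Fara-7B"}
        self.current_id = None
        self.model = None

    def load(self, key: str, loader):
        model_id = self.registry[key]
        if model_id == self.current_id and self.model is not None:
            return self.model         # already resident: nothing to do
        if self.model is not None:
            del self.model            # drop the only reference to the weights
            self.model = None
            gc.collect()              # collect the tensors' Python wrappers
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # hand cached blocks back to the driver
        self.model = loader(model_id)    # e.g. a from_pretrained wrapper
        self.current_id = model_id
        return self.model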
@@ -83,6 +117,7 @@ class FaraTransformersModel:
         )
         inputs = inputs.to(self.model.device)
 
+        # Generate
         with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
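The lines between these two hunks (extracting the vision inputs and building the processor batch) are unchanged and therefore collapsed in the diff; they follow the usual Qwen2.5-VL preprocessing recipe. Roughly, assuming processor, model, and a messages chat list are already in scope (a sketch, not the verbatim file):

import torch
from qwen_vl_utils import process_vision_info

# Render the chat template, then split out the image/video inputs.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text], images=image_inputs, videos=video_inputs,
    padding=True, return_tensors="pt",
).to(model.device)

with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=512)

# Trim the prompt tokens so only the newly generated text is decoded.
generated_ids_trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
print(processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0])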
@@ -94,12 +129,11 @@ class FaraTransformersModel:
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]
 
-# Initialize Model
-print(f"Initializing model class for {MODEL_ID}...")
-fara_model = FaraTransformersModel(MODEL_ID, to_device=DEVICE)
+# Global instance
+model_manager = ModelManager()
 
 # -----------------------------------------------------------------------------
-# 3. PARSING & VISUALIZATION LOGIC (UPDATED)
+# 3. PARSING & VISUALIZATION LOGIC
 # -----------------------------------------------------------------------------
 
 def array_to_image(image_array: np.ndarray) -> Image.Image:
@@ -118,25 +152,17 @@ def get_navigation_prompt(task, image):
 
 def parse_tool_calls(response: str) -> list[dict]:
     """
-    Parses the <tool_call>{JSON}</tool_call> format specifically.
-    Extracts coordinates and action types.
+    Parses <tool_call>{JSON}</tool_call> tags.
+    Also attempts fallback regex for plain coordinate output just in case.
     """
     actions = []
 
-    # Regex to find content between <tool_call> tags
-    # re.DOTALL allows matching across newlines
-    matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
-
-    for match in matches:
+    # 1. Try Specific JSON Tool Call
+    json_matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
+    for match in json_matches:
         try:
-            # Clean up the string just in case
-            json_str = match.strip()
-            data = json.loads(json_str)
-
-            # Access the 'arguments' dictionary
+            data = json.loads(match.strip())
             args = data.get("arguments", {})
-
-            # Extract coordinates: Expecting list like [399, 496]
             coords = args.get("coordinate", [])
             action_type = args.get("action", "unknown")
             text_content = args.get("text", "")
@@ -147,16 +173,23 @@ def parse_tool_calls(response: str) -> list[dict]:
                 "x": float(coords[0]),
                 "y": float(coords[1]),
                 "text": text_content,
-                "raw_json": data
+                "source": "json"
             })
-            print(f"Parsed Action: {action_type} at {coords}")
-        else:
-            print(f"No valid coordinates found in tool call: {json_str}")
-
-        except json.JSONDecodeError as e:
-            print(f"Failed to parse JSON in tool call: {e}\nString was: {match}")
-        except Exception as e:
-            print(f"Unexpected error parsing tool call: {e}")
+        except:
+            pass
+
+    # 2. Fallback: Search for any [x, y] or (x, y) pattern if JSON parsing yielded nothing
+    if not actions:
+        # Regex for [123, 456] or (123, 456)
+        coord_matches = re.findall(r"[\[\(](\d+(?:\.\d+)?),\s*(\d+(?:\.\d+)?)[\]\)]", response)
+        for x, y in coord_matches:
+            actions.append({
+                "type": "click",  # Assume click for raw coords
+                "x": float(x),
+                "y": float(y),
+                "text": "",
+                "source": "regex"
+            })
 
     return actions
 
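Condensed and run standalone, the two-stage parser behaves like this (the sample strings are illustrative, not real model output, and the length guard on coords is implied by the surrounding context rather than shown in the hunk):

import json
import re

def parse_tool_calls(response: str) -> list[dict]:
    """Prefer <tool_call>{JSON}</tool_call> blocks; fall back to bare (x, y) pairs."""
    actions = []
    for match in re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL):
        try:
            args = json.loads(match.strip()).get("arguments", {})
            coords = args.get("coordinate", [])
            if len(coords) >= 2:
                actions.append({"type": args.get("action", "unknown"),
                                "x": float(coords[0]), "y": float(coords[1]),
                                "text": args.get("text", ""), "source": "json"})
        except Exception:
            pass
    if not actions:  # fallback: any [x, y] or (x, y) pair in plain text
        for x, y in re.findall(r"[\[\(](\d+(?:\.\d+)?),\s*(\d+(?:\.\d+)?)[\]\)]", response):
            actions.append({"type": "click", "x": float(x), "y": float(y),
                            "text": "", "source": "regex"})
    return actions

print(parse_tool_calls('<tool_call>{"arguments": {"action": "click", "coordinate": [399, 496]}}</tool_call>'))
print(parse_tool_calls("I would click at (120, 340)."))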
@@ -169,7 +202,6 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
     draw = ImageDraw.Draw(img_copy)
     width, height = img_copy.size
 
-    # Try loading font
     try:
         font = ImageFont.load_default()
     except:
@@ -184,49 +216,46 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
         'unknown': 'green'
     }
 
-    for i, act in enumerate(actions):
+    for act in actions:
         x = act['x']
         y = act['y']
 
-        # Check if Normalized (0.0 - 1.0) or Absolute (Pixels > 1.0)
-        # The logs showed [399, 496], so these are pixels.
-        # However, to be safe, we check.
+        # Coordinate Normalization check
         if x <= 1.0 and y <= 1.0 and x > 0:
-            # It's normalized, convert to pixels
             pixel_x = int(x * width)
             pixel_y = int(y * height)
         else:
-            # It's absolute pixels
             pixel_x = int(x)
             pixel_y = int(y)
 
         action_type = act['type']
         color = colors.get(action_type, 'green')
 
-        # Draw Circle Target
-        r = 12 # Radius
+        # Draw Target
+        r = 12
         draw.ellipse(
             [pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r],
             outline=color,
             width=4
         )
-
-        # Draw Center Dot
         draw.ellipse(
             [pixel_x - 3, pixel_y - 3, pixel_x + 3, pixel_y + 3],
             fill=color
         )
 
-        # Draw Label text
+        # Label
         label_text = f"{action_type}"
         if act['text']:
             label_text += f": '{act['text']}'"
 
-        # Draw text background for readability
         text_pos = (pixel_x + 15, pixel_y - 10)
-        bbox = draw.textbbox(text_pos, label_text, font=font)
-        draw.rectangle(bbox, fill="black")
-        draw.text(text_pos, label_text, fill="white", font=font)
+        # Bounding box for text background
+        if font:
+            bbox = draw.textbbox(text_pos, label_text, font=font)
+            draw.rectangle(bbox, fill="black")
+            draw.text(text_pos, label_text, fill="white", font=font)
+        else:
+            draw.text(text_pos, label_text, fill="black")  # fallback
 
     return img_copy
 
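The coordinate check in the loop above treats values in (0, 1] as normalized fractions of the image and anything larger as raw pixels. Isolated into a hypothetical helper for clarity:

def to_pixels(x: float, y: float, width: int, height: int) -> tuple[int, int]:
    """Map model coordinates to pixels: (0, 1] means normalized, otherwise absolute."""
    if x <= 1.0 and y <= 1.0 and x > 0:
        return int(x * width), int(y * height)
    return int(x), int(y)

print(to_pixels(0.5, 0.25, 1280, 720))  # normalized -> (640, 180)
print(to_pixels(399, 496, 1280, 720))   # absolute   -> (399, 496)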
@@ -234,29 +263,30 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
 # 4. GRADIO LOGIC
 # -----------------------------------------------------------------------------
 
-@spaces.GPU(duration=60)
-def process_screenshot(input_numpy_image: np.ndarray, task: str) -> Tuple[str, Optional[Image.Image]]:
+@spaces.GPU(duration=120)
+def process_screenshot(model_choice: str, input_numpy_image: np.ndarray, task: str) -> Tuple[str, Optional[Image.Image]]:
     if input_numpy_image is None:
         return "⚠️ Please upload an image first.", None
 
     # Convert to PIL
     input_pil_image = array_to_image(input_numpy_image)
 
-    # 1. Build Prompt
+    # Build Prompt
     prompt = get_navigation_prompt(task, input_pil_image)
 
-    # 2. Generate Response
-    if fara_model is None:
-        raise ValueError("Model not loaded")
-
-    print("Generating response...")
-    raw_response = fara_model.generate(prompt, max_new_tokens=500)
+    # Generate Response
+    print(f"Generating response with {model_choice}...")
+    try:
+        raw_response = model_manager.generate(model_choice, prompt, max_new_tokens=500)
+    except Exception as e:
+        return f"Error generating response: {str(e)}", None
+
     print(f"Raw Output:\n{raw_response}")
 
-    # 3. Parse Actions
+    # Parse Actions
     actions = parse_tool_calls(raw_response)
 
-    # 4. Visualize
+    # Visualize
     output_image = input_pil_image
     if actions:
         visualized = create_localized_image(input_pil_image, actions)
@@ -269,10 +299,10 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str) -> Tuple[str, Optional[Image.Image]]:
 # 5. GRADIO UI SETUP
 # -----------------------------------------------------------------------------
 
-title = "Fara-7B GUI Operator 🖥️"
+title = "CUA GUI Operator 🖥️"
 description = """
-This demo uses **microsoft/Fara-7B** to understand GUI screenshots.
-It generates action coordinates which are then parsed and plotted on the image.
+This demo uses **Vision Language Models** to understand GUI screenshots and generate actions.
+Select a model, upload a screenshot, and define a task.
 """
 
 custom_css = """
@@ -285,6 +315,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
 
     with gr.Row():
         with gr.Column():
+            # Model Selector
+            model_selector = gr.Dropdown(
+                label="Choose CUA Model",
+                choices=["Fara-7B", "UI-TARS-1.5-7B"],
+                value="Fara-7B",
+                interactive=True
+            )
+
             input_image = gr.Image(label="Upload Screenshot", height=500)
             task_input = gr.Textbox(
                 label="Task Instruction",
@@ -300,16 +338,16 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
     # Wire up the button
     submit_btn.click(
         fn=process_screenshot,
-        inputs=[input_image, task_input],
+        inputs=[model_selector, input_image, task_input],
         outputs=[output_text, output_image]
     )
 
     # Example for quick testing
     gr.Examples(
         examples=[
-            ["./assets/google.png", "Search for 'Hugging Face'"],
+            ["Fara-7B", "./assets/google.png", "Search for 'Hugging Face'"],
         ],
-        inputs=[input_image, task_input],
+        inputs=[model_selector, input_image, task_input],
         label="Quick Examples"
     )
 
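Because the dropdown is simply prepended to the inputs list, Gradio passes its value as the handler's first positional argument. A stripped-down sketch of the same wiring, with a stub handler standing in for the real pipeline:

import gradio as gr

def run(model_choice: str, task: str) -> str:
    # Stub: the Space routes this through model_manager.generate(...) instead
    return f"Would run {model_choice!r} on task: {task!r}"

with gr.Blocks() as demo:
    model_selector = gr.Dropdown(
        label="Choose CUA Model",
        choices=["Fara-7B", "UI-TARS-1.5-7B"],
        value="Fara-7B",
    )
    task_input = gr.Textbox(label="Task Instruction")
    output = gr.Textbox(label="Output")
    # Dropdown first in inputs -> first argument of run()
    gr.Button("Run").click(fn=run, inputs=[model_selector, task_input], outputs=[output])

demo.launch()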
 
353