prithivMLmods commited on
Commit
c875d85
·
verified ·
1 Parent(s): 36aecd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +320 -359
app.py CHANGED
@@ -1,106 +1,159 @@
1
  import os
2
  import re
3
  import json
4
- import time
5
- import shutil
6
- import uuid
7
- import tempfile
8
- import unicodedata
9
- from io import BytesIO
10
- from typing import Tuple, Optional, List, Dict, Any
11
-
12
- import gradio as gr
13
  import numpy as np
14
  import torch
15
  import spaces
 
16
  from PIL import Image, ImageDraw, ImageFont
 
17
 
18
- # Transformers imports
19
  from transformers import (
20
  Qwen2_5_VLForConditionalGeneration,
21
  AutoProcessor,
22
  )
23
  from qwen_vl_utils import process_vision_info
24
 
25
- # Selenium Imports
26
- from selenium import webdriver
27
- from selenium.webdriver.chrome.service import Service as ChromeService
28
- from selenium.webdriver.chrome.options import Options as ChromeOptions
29
- from selenium.webdriver.common.action_chains import ActionChains
30
- from selenium.webdriver.common.by import By
31
- from selenium.webdriver.common.keys import Keys
32
- from webdriver_manager.chrome import ChromeDriverManager
33
-
34
  # -----------------------------------------------------------------------------
35
- # CONSTANTS & CONFIG
36
  # -----------------------------------------------------------------------------
37
 
38
- MODEL_ID = "microsoft/Fara-7B"
39
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
- WIDTH = 1024
41
- HEIGHT = 768
42
- TMP_DIR = "./tmp"
43
- if not os.path.exists(TMP_DIR):
44
- os.makedirs(TMP_DIR)
45
-
46
- # System Prompt
47
- # We ask for Python code, but we will also handle the JSON format the model seems to prefer.
48
- OS_SYSTEM_PROMPT = f"""You are a GUI automation agent controlling a Chrome browser.
49
- Current Screen Resolution: {WIDTH}x{HEIGHT}.
50
-
51
- You will receive a screenshot and a task. Generate the next action to complete the task.
52
-
53
- Supported Actions (Python Format):
54
- 1. `click(x=200, y=300)`: Left click at specific pixel coordinates.
55
- 2. `type_text(text="hello")`: Type text.
56
- 3. `press_key(key="enter")`: Press a key.
57
- 4. `scroll(amount=2, direction="down")`: Scroll.
58
- 5. `open_url(url="https://google.com")`: Open a URL.
59
-
60
- Important:
61
- - Use precise PIXEL coordinates from the screenshot.
62
- - Wrap your action in <code> tags.
63
- - If you need to search, click the search bar first, then type, then press enter.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  """
65
 
66
  # -----------------------------------------------------------------------------
67
- # MODEL WRAPPER
68
  # -----------------------------------------------------------------------------
69
 
70
- class FaraModelWrapper:
71
  def __init__(self, model_id: str, to_device: str = "cuda"):
72
- print(f"Loading {model_id} on {to_device}...")
73
  self.model_id = model_id
74
 
 
75
  try:
76
  self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
 
77
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
78
  model_id,
79
  trust_remote_code=True,
80
- torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
81
  device_map="auto" if to_device == "cuda" else None,
82
  )
83
  if to_device == "cpu":
84
  self.model.to("cpu")
85
- self.model.eval()
86
  print("Model loaded successfully.")
87
  except Exception as e:
88
- print(f"Failed to load Fara, falling back to Qwen2.5-VL-7B. Error: {e}")
89
- fallback_id = "Qwen/Qwen2.5-VL-7B-Instruct"
90
- self.processor = AutoProcessor.from_pretrained(fallback_id, trust_remote_code=True)
91
- self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
92
- fallback_id,
93
- trust_remote_code=True,
94
- torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
95
- device_map="auto",
96
- )
97
 
98
- def generate(self, messages: list[dict], max_new_tokens=512):
 
99
  text = self.processor.apply_chat_template(
100
  messages, tokenize=False, add_generation_prompt=True
101
  )
 
 
102
  image_inputs, video_inputs = process_vision_info(messages)
103
 
 
104
  inputs = self.processor(
105
  text=[text],
106
  images=image_inputs,
@@ -110,12 +163,10 @@ class FaraModelWrapper:
110
  )
111
  inputs = inputs.to(self.model.device)
112
 
113
- with torch.no_grad():
114
- generated_ids = self.model.generate(
115
- **inputs,
116
- max_new_tokens=max_new_tokens
117
- )
118
 
 
119
  generated_ids_trimmed = [
120
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
121
  ]
@@ -126,330 +177,240 @@ class FaraModelWrapper:
126
 
127
  return output_text
128
 
129
- # Initialize global model
130
- model = FaraModelWrapper(MODEL_ID, DEVICE)
131
-
132
  # -----------------------------------------------------------------------------
133
- # SELENIUM SANDBOX
134
  # -----------------------------------------------------------------------------
135
 
136
- def get_system_chrome_path():
137
- paths = ["/usr/bin/chromium", "/usr/bin/chromium-browser", "/usr/bin/google-chrome"]
138
- for p in paths:
139
- if os.path.exists(p): return p
140
- return None
141
-
142
- class SeleniumSandbox:
143
- def __init__(self, width=1024, height=768):
144
- self.width = width
145
- self.height = height
146
- self.tmp_dir = tempfile.mkdtemp(prefix="chrome_sandbox_")
147
-
148
- chrome_opts = ChromeOptions()
149
- binary_path = get_system_chrome_path()
150
- if binary_path: chrome_opts.binary_location = binary_path
151
-
152
- chrome_opts.add_argument("--headless=new")
153
- chrome_opts.add_argument(f"--user-data-dir={self.tmp_dir}")
154
- chrome_opts.add_argument(f"--window-size={width},{height}")
155
- chrome_opts.add_argument("--no-sandbox")
156
- chrome_opts.add_argument("--disable-dev-shm-usage")
157
- chrome_opts.add_argument("--disable-gpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- try:
160
- system_driver_path = "/usr/bin/chromedriver"
161
- if os.path.exists(system_driver_path):
162
- service = ChromeService(executable_path=system_driver_path)
163
- else:
164
- service = ChromeService(ChromeDriverManager().install())
165
-
166
- self.driver = webdriver.Chrome(service=service, options=chrome_opts)
167
- self.driver.set_window_size(width, height)
168
- print("Selenium started.")
169
- except Exception as e:
170
- print(f"Selenium init failed: {e}")
171
- shutil.rmtree(self.tmp_dir, ignore_errors=True)
172
- raise e
173
 
174
- def get_screenshot(self):
175
- return Image.open(BytesIO(self.driver.get_screenshot_as_png()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- def execute_action(self, action_data: dict):
178
- """Execute parsed action on the browser"""
179
- action_type = action_data.get('type')
180
- print(f"Executing action: {action_data}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- try:
183
- actions = ActionChains(self.driver)
184
- body = self.driver.find_element(By.TAG_NAME, "body")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
- # Helper: Handle both normalized (0-1) and pixel coordinates
187
- def get_coords(data):
188
- x, y = data.get('x', 0), data.get('y', 0)
189
- if x <= 1.0 and y <= 1.0 and x > 0: # Likely normalized
190
- x = int(x * self.width)
191
- y = int(y * self.height)
192
- else: # Likely pixels
193
- x = int(x)
194
- y = int(y)
195
- return x, y
196
-
197
- if action_type in ['click', 'left_click', 'right_click', 'double_click']:
198
- x_px, y_px = get_coords(action_data)
199
-
200
- # Reset pointer to top-left then move
201
- actions.move_to_element_with_offset(body, 0, 0)
202
- actions.move_by_offset(x_px, y_px)
203
-
204
- if action_type in ['click', 'left_click']: actions.click()
205
- elif action_type == 'right_click': actions.context_click()
206
- elif action_type == 'double_click': actions.double_click()
207
- actions.perform()
208
- return f"Clicked at {x_px}, {y_px}"
209
-
210
- elif action_type == 'type_text':
211
- text = action_data.get('text', '')
212
- press_enter = action_data.get('press_enter', False)
213
-
214
- # Check if this type action came with coordinates (from JSON log)
215
- if 'x' in action_data and 'y' in action_data:
216
- x_px, y_px = get_coords(action_data)
217
- actions.move_to_element_with_offset(body, 0, 0)
218
- actions.move_by_offset(x_px, y_px)
219
- actions.click()
220
-
221
- actions.send_keys(text)
222
- if press_enter:
223
- actions.send_keys(Keys.ENTER)
224
- actions.perform()
225
- return f"Typed '{text}'"
226
-
227
- elif action_type == 'press_key':
228
- key_name = action_data.get('key', '').lower()
229
- k = getattr(Keys, key_name.upper(), None)
230
- if not k:
231
- if key_name == "enter": k = Keys.ENTER
232
- elif key_name == "space": k = Keys.SPACE
233
- if k:
234
- actions.send_keys(k)
235
- actions.perform()
236
- return f"Pressed {key_name}"
237
 
238
- elif action_type == 'scroll':
239
- amount = action_data.get('amount', 2)
240
- scroll_y = amount * 100
241
- self.driver.execute_script(f"window.scrollBy(0, {scroll_y});")
242
- return "Scrolled"
243
-
244
- elif action_type == 'open_url':
245
- url = action_data.get('url', '')
246
- if not url.startswith('http'): url = 'https://' + url
247
- self.driver.get(url)
248
- time.sleep(2) # Wait for load
249
- return f"Opened {url}"
250
-
251
- return f"Unknown action {action_type}"
252
- except Exception as e:
253
- return f"Action failed: {e}"
254
-
255
- def cleanup(self):
256
- try: self.driver.quit()
257
- except: pass
258
- shutil.rmtree(self.tmp_dir, ignore_errors=True)
259
 
260
  # -----------------------------------------------------------------------------
261
- # PARSING LOGIC (Fixed to handle JSON logs)
262
  # -----------------------------------------------------------------------------
263
 
264
- def parse_model_response(response: str) -> dict:
265
- """
266
- Parses both:
267
- 1. <code>click(x=...)</code> (Python style)
268
- 2. <tool_call>{...}</tool_call> (JSON style seen in logs)
269
- """
270
-
271
- # Check for JSON Tool Call first (Priority based on logs)
272
- tool_match = re.search(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
273
- if tool_match:
274
- try:
275
- tool_data = json.loads(tool_match.group(1))
276
- name = tool_data.get("name")
277
- args = tool_data.get("arguments", {})
278
-
279
- # Map JSON schema to our internal schema
280
- if name == "Navigate":
281
- if "url" in args:
282
- return {"type": "open_url", "url": args["url"]}
283
- elif "action" in args:
284
- action_sub = args["action"]
285
- coords = args.get("coordinate", [0, 0])
286
- x, y = coords[0], coords[1]
287
-
288
- if action_sub == "left_click":
289
- return {"type": "click", "x": x, "y": y}
290
- elif action_sub == "type":
291
- text = args.get("text", "")
292
- enter = args.get("press_enter", False)
293
- return {"type": "type_text", "text": text, "x": x, "y": y, "press_enter": enter}
294
-
295
- elif name == "Type":
296
- return {
297
- "type": "type_text",
298
- "text": args.get("text", ""),
299
- "press_enter": args.get("press_enter", False)
300
- }
301
-
302
- except Exception as e:
303
- print(f"JSON Parse Error: {e}")
304
 
305
- # Fallback to Python Code Block
306
- code_match = re.search(r"<code>\s*(.*?)\s*</code>", response, re.DOTALL)
307
- if code_match:
308
- action_str = code_match.group(1)
309
-
310
- # Regex for Python style
311
- coord_match = re.match(r"(\w+)\s*\(\s*x\s*=\s*([0-9.]+)\s*,\s*y\s*=\s*([0-9.]+)\s*\)", action_str)
312
- if coord_match:
313
- return {"type": coord_match.group(1), "x": float(coord_match.group(2)), "y": float(coord_match.group(3))}
314
-
315
- url_match = re.match(r"open_url\s*\(\s*url\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
316
- if url_match: return {"type": "open_url", "url": url_match.group(1)}
317
-
318
- text_match = re.match(r"type_text\s*\(\s*text\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
319
- if text_match: return {"type": "type_text", "text": text_match.group(1)}
320
-
321
- return {}
322
 
323
  # -----------------------------------------------------------------------------
324
- # MAIN LOOP
325
  # -----------------------------------------------------------------------------
326
 
327
- @spaces.GPU(duration=180)
328
- def agent_step(task_instruction: str, history: list, sandbox_state: dict):
329
- # Retrieve or Create Sandbox
330
- if 'uuid' not in sandbox_state:
331
- sandbox_state['uuid'] = str(uuid.uuid4())
332
-
333
- sid = sandbox_state['uuid']
334
- if sid not in SANDBOX_REGISTRY:
335
- SANDBOX_REGISTRY[sid] = SeleniumSandbox(WIDTH, HEIGHT)
336
-
337
- sandbox = SANDBOX_REGISTRY[sid]
338
-
339
- # 1. Get Screenshot
340
- screenshot = sandbox.get_screenshot()
341
-
342
- # 2. Construct Prompt
343
- # We append the history of actions to help the model know state
344
- history_text = "\n".join([h.split('\nAction:')[1].strip() if 'Action:' in h else '' for h in history[-3:]])
345
 
346
- messages = [
347
- {"role": "system", "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]},
348
- {"role": "user", "content": [
349
- {"type": "image", "image": screenshot},
350
- {"type": "text", "text": f"Task: {task_instruction}\nPrevious Actions Summary: {history_text}"}
351
- ]}
352
- ]
353
 
354
- # 3. Model Inference
355
- response = model.generate(messages)
 
 
356
 
357
- # 4. Parse Action
358
- action_data = parse_model_response(response)
359
 
360
- log_entry = f"Step: {len(history)+1}\nThought: {response}\nParsed: {action_data}"
 
 
 
 
361
 
362
- # 5. Execute Action
363
- if action_data:
364
- execution_result = sandbox.execute_action(action_data)
365
-
366
- # Visual Marker
367
- if 'x' in action_data and 'y' in action_data:
368
- draw = ImageDraw.Draw(screenshot)
369
- # Handle mixed coord types for drawing
370
- x_px = action_data['x']
371
- y_px = action_data['y']
372
- if x_px <= 1.0: x_px *= WIDTH
373
- if y_px <= 1.0: y_px *= HEIGHT
374
-
375
- r = 10
376
- draw.ellipse((x_px-r, y_px-r, x_px+r, y_px+r), outline="red", width=3)
377
- else:
378
- execution_result = "No valid action parsed."
379
 
380
- log_entry += f"\nResult: {execution_result}"
381
- history.append(log_entry)
382
-
383
- return screenshot, history, sandbox_state
384
-
385
- # Global registry
386
- SANDBOX_REGISTRY = {}
387
-
388
- def cleanup_sandbox(sandbox_state):
389
- sid = sandbox_state.get('uuid')
390
- if sid and sid in SANDBOX_REGISTRY:
391
- SANDBOX_REGISTRY[sid].cleanup()
392
- del SANDBOX_REGISTRY[sid]
393
- return [], {}
394
-
395
- # -----------------------------------------------------------------------------
396
- # GRADIO UI
397
- # -----------------------------------------------------------------------------
398
 
399
- def run_task_loop(task, history, state):
400
- max_steps = 15
401
-
402
- for i in range(max_steps):
403
- try:
404
- screenshot, new_history, new_state = agent_step(task, history, state)
405
- history = new_history
406
-
407
- logs_text = "\n\n" + "="*40 + "\n\n".join(history)
408
- yield screenshot, logs_text, state
409
-
410
- if "Done" in history[-1] or "finished" in history[-1].lower():
411
- break
412
-
413
- time.sleep(1)
414
- except Exception as e:
415
- error_msg = f"Error in loop: {e}"
416
- history.append(error_msg)
417
- yield None, "\n".join(history), state
418
- break
419
 
420
- custom_css = "#view_img { height: 600px; object-fit: contain; }"
 
 
421
 
422
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
423
- state = gr.State({})
424
- history = gr.State([])
425
-
426
- gr.Markdown("# 🤖 Fara CUA - Chrome Agent")
427
-
428
  with gr.Row():
429
  with gr.Column(scale=1):
430
- task_input = gr.Textbox(label="Task Instruction", value="Go to google.com and search for 'SpaceX'")
431
- run_btn = gr.Button("Run Agent", variant="primary")
432
- clear_btn = gr.Button("Reset / Clear")
433
-
434
- with gr.Column(scale=2):
435
- browser_view = gr.Image(label="Live Browser View", elem_id="view_img", interactive=False)
436
 
437
- logs_output = gr.Textbox(label="Agent Logs", lines=15, interactive=False)
438
-
439
- run_btn.click(
440
- fn=run_task_loop,
441
- inputs=[task_input, history, state],
442
- outputs=[browser_view, logs_output, state]
443
- )
444
 
445
- clear_btn.click(
446
- fn=cleanup_sandbox,
447
- inputs=[state],
448
- outputs=[history, state]
449
- ).then(
450
- lambda: (None, ""),
451
- outputs=[browser_view, logs_output]
452
  )
453
 
 
 
 
454
  if __name__ == "__main__":
455
- demo.launch(share=True)
 
1
  import os
2
  import re
3
  import json
 
 
 
 
 
 
 
 
 
4
  import numpy as np
5
  import torch
6
  import spaces
7
+ import gradio as gr
8
  from PIL import Image, ImageDraw, ImageFont
9
+ from typing import Tuple, Optional, List, Dict, Any
10
 
11
+ # Transformers & Qwen Utils
12
  from transformers import (
13
  Qwen2_5_VLForConditionalGeneration,
14
  AutoProcessor,
15
  )
16
  from qwen_vl_utils import process_vision_info
17
 
 
 
 
 
 
 
 
 
 
18
  # -----------------------------------------------------------------------------
19
+ # 1. PROMPTS (from prompt.py)
20
  # -----------------------------------------------------------------------------
21
 
22
+ OS_ACTIONS = """
23
+ def final_answer(answer: any) -> any:
24
+ \"\"\"
25
+ Provides a final answer to the given problem.
26
+ Args:
27
+ answer: The final answer to the problem
28
+ \"\"\"
29
+
30
+ def move_mouse(self, x: float, y: float) -> str:
31
+ \"\"\"
32
+ Moves the mouse cursor to the specified coordinates
33
+ Args:
34
+ x: The x coordinate (horizontal position)
35
+ y: The y coordinate (vertical position)
36
+ \"\"\"
37
+
38
+ def click(x: Optional[float] = None, y: Optional[float] = None) -> str:
39
+ \"\"\"
40
+ Performs a left-click at the specified normalized coordinates
41
+ Args:
42
+ x: The x coordinate (horizontal position)
43
+ y: The y coordinate (vertical position)
44
+ \"\"\"
45
+
46
+ def double_click(x: Optional[float] = None, y: Optional[float] = None) -> str:
47
+ \"\"\"
48
+ Performs a double-click at the specified normalized coordinates
49
+ Args:
50
+ x: The x coordinate (horizontal position)
51
+ y: The y coordinate (vertical position)
52
+ \"\"\"
53
+
54
+ def type(text: str) -> str:
55
+ \"\"\"
56
+ Types the specified text at the current cursor position.
57
+ Args:
58
+ text: The text to type
59
+ \"\"\"
60
+
61
+ def press(keys: str | list[str]) -> str:
62
+ \"\"\"
63
+ Presses a keyboard key
64
+ Args:
65
+ keys: The key or list of keys to press (e.g. "enter", "space", "backspace", "ctrl", etc.).
66
+ \"\"\"
67
+
68
+ def navigate_back() -> str:
69
+ \"\"\"
70
+ Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly.
71
+ \"\"\"
72
+
73
+ def drag(from_coord: list[float], to_coord: list[float]) -> str:
74
+ \"\"\"
75
+ Clicks [x1, y1], drags mouse to [x2, y2], then release click.
76
+ Args:
77
+ x1: origin x coordinate
78
+ y1: origin y coordinate
79
+ x2: end x coordinate
80
+ y2: end y coordinate
81
+ \"\"\"
82
+
83
+ def scroll(direction: Literal["up", "down"] = "down", amount: int = 1) -> str:
84
+ \"\"\"
85
+ Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus.
86
+ Args:
87
+ x: The x coordinate (horizontal position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
88
+ y: The y coordinate (vertical position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates
89
+ direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out.
90
+ amount: The amount to scroll. A good amount is 1 or 2.
91
+ \"\"\"
92
+
93
+ def wait(seconds: float) -> str:
94
+ \"\"\"
95
+ Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps)
96
+ Args:
97
+ seconds: Number of seconds to wait, generally 2 is enough.
98
+ \"\"\"
99
+ """
100
+
101
+ OS_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls.
102
+
103
+ For each step:
104
+ • First, <think></think> to express the thought process guiding your next action and the reasoning behind it.
105
+ • Then, use <code></code> to perform the action. it will be executed in a stateful environment.
106
+
107
+ The following functions are exposed to the Python interpreter:
108
+ <code>
109
+ {OS_ACTIONS}
110
+ </code>
111
+
112
+ The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
113
  """
114
 
115
  # -----------------------------------------------------------------------------
116
+ # 2. MODEL WRAPPER (Modified for Fara/QwenVL)
117
  # -----------------------------------------------------------------------------
118
 
119
+ class TransformersModel:
120
  def __init__(self, model_id: str, to_device: str = "cuda"):
121
+ print(f"Loading model: {model_id}...")
122
  self.model_id = model_id
123
 
124
+ # Load Processor
125
  try:
126
  self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
127
+ except Exception as e:
128
+ print(f"Error loading processor: {e}")
129
+ raise e
130
+
131
+ # Load Model
132
+ try:
133
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
134
  model_id,
135
  trust_remote_code=True,
136
+ torch_dtype=torch.bfloat16,
137
  device_map="auto" if to_device == "cuda" else None,
138
  )
139
  if to_device == "cpu":
140
  self.model.to("cpu")
141
+
142
  print("Model loaded successfully.")
143
  except Exception as e:
144
+ print(f"Error loading Fara/Qwen model: {e}. Ensure you have access/internet.")
145
+ raise e
 
 
 
 
 
 
 
146
 
147
+ def generate(self, messages: list[dict], **kwargs):
148
+ # 1. Prepare text prompt using chat template
149
  text = self.processor.apply_chat_template(
150
  messages, tokenize=False, add_generation_prompt=True
151
  )
152
+
153
+ # 2. Process images/videos
154
  image_inputs, video_inputs = process_vision_info(messages)
155
 
156
+ # 3. Create model inputs
157
  inputs = self.processor(
158
  text=[text],
159
  images=image_inputs,
 
163
  )
164
  inputs = inputs.to(self.model.device)
165
 
166
+ # 4. Generate
167
+ generated_ids = self.model.generate(**inputs, **kwargs)
 
 
 
168
 
169
+ # 5. Decode (trimming input tokens)
170
  generated_ids_trimmed = [
171
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
172
  ]
 
177
 
178
  return output_text
179
 
 
 
 
180
  # -----------------------------------------------------------------------------
181
+ # 3. HELPER FUNCTIONS
182
  # -----------------------------------------------------------------------------
183
 
184
+ def array_to_image(image_array: np.ndarray) -> Image.Image:
185
+ if image_array is None:
186
+ raise ValueError("No image provided. Please upload an image before submitting.")
187
+ return Image.fromarray(np.uint8(image_array))
188
+
189
+ def get_navigation_prompt(task, image):
190
+ """Constructs the prompt messages for the model"""
191
+ return [
192
+ {
193
+ "role": "system",
194
+ "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}],
195
+ },
196
+ {
197
+ "role": "user",
198
+ "content": [
199
+ {"type": "image", "image": image},
200
+ {"type": "text", "text": f"Instruction: {task}\n\nPrevious actions:\nNone"},
201
+ ],
202
+ },
203
+ ]
204
+
205
+ def parse_actions_from_response(response: str) -> list[str]:
206
+ """Parse actions from model response using regex pattern."""
207
+ # Look for code block
208
+ pattern = r"<code>\s*(.*?)\s*</code>"
209
+ matches = re.findall(pattern, response, re.DOTALL)
210
+
211
+ # If no code block, try to find raw function calls if the model forgot tags
212
+ if not matches:
213
+ # Fallback: look for lines starting with known functions
214
+ funcs = ["click", "type", "press", "drag", "scroll", "wait"]
215
+ lines = response.split('\n')
216
+ found = []
217
+ for line in lines:
218
+ line = line.strip()
219
+ if any(line.startswith(f) for f in funcs):
220
+ found.append(line)
221
+ return found
222
 
223
+ return matches
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
+ def extract_coordinates_from_action(action_code: str) -> list[dict]:
226
+ """Extract coordinates from action code for localization actions."""
227
+ localization_actions = []
228
+
229
+ # Patterns for different action types
230
+ patterns = {
231
+ 'click': r'click\((?:x=)?([0-9.]+)(?:,\s*(?:y=)?([0-9.]+))?\)',
232
+ 'double_click': r'double_click\((?:x=)?([0-9.]+)(?:,\s*(?:y=)?([0-9.]+))?\)',
233
+ 'move_mouse': r'move_mouse\((?:self,\s*)?(?:x=)?([0-9.]+)(?:,\s*(?:y=)?([0-9.]+))\)',
234
+ 'drag': r'drag\(\[([0-9.]+),\s*([0-9.]+)\],\s*\[([0-9.]+),\s*([0-9.]+)\]\)'
235
+ }
236
+
237
+ for action_type, pattern in patterns.items():
238
+ matches = re.finditer(pattern, action_code)
239
+ for match in matches:
240
+ if action_type == 'drag':
241
+ # Drag has from and to coordinates
242
+ from_x, from_y, to_x, to_y = match.groups()
243
+ localization_actions.append({
244
+ 'type': 'drag_from', 'x': float(from_x), 'y': float(from_y), 'action': action_type
245
+ })
246
+ localization_actions.append({
247
+ 'type': 'drag_to', 'x': float(to_x), 'y': float(to_y), 'action': action_type
248
+ })
249
+ else:
250
+ # Single coordinate actions
251
+ if match.groups()[0]:
252
+ x_val = match.group(1)
253
+ y_val = match.group(2) if match.group(2) else x_val
254
+
255
+ # Convert pixel coords to normalized if they look like pixels (assuming > 1000 width usually)
256
+ # Note: The prompt implies normalized (0.0-1.0), but if model outputs 500, we handle it visually later
257
+
258
+ if x_val and y_val:
259
+ localization_actions.append({
260
+ 'type': action_type,
261
+ 'x': float(x_val),
262
+ 'y': float(y_val),
263
+ 'action': action_type
264
+ })
265
+
266
+ return localization_actions
267
 
268
+ def create_localized_image(original_image: Image.Image, coordinates: list[dict]) -> Optional[Image.Image]:
269
+ """Create an image with localization markers drawn on it."""
270
+ if not coordinates:
271
+ return None
272
+
273
+ img_copy = original_image.copy()
274
+ draw = ImageDraw.Draw(img_copy)
275
+ width, height = img_copy.size
276
+
277
+ try:
278
+ font = ImageFont.load_default()
279
+ except:
280
+ font = None
281
+
282
+ colors = {
283
+ 'click': 'red', 'double_click': 'blue', 'move_mouse': 'green',
284
+ 'drag_from': 'orange', 'drag_to': 'purple'
285
+ }
286
+
287
+ for i, coord in enumerate(coordinates):
288
+ # Handle normalized vs pixel coordinates
289
+ x, y = coord['x'], coord['y']
290
 
291
+ if x <= 1.0 and y <= 1.0:
292
+ pixel_x = int(x * width)
293
+ pixel_y = int(y * height)
294
+ else:
295
+ pixel_x = int(x)
296
+ pixel_y = int(y)
297
+
298
+ color = colors.get(coord['type'], 'red')
299
+
300
+ # Draw Circle
301
+ r = 8
302
+ draw.ellipse([pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r],
303
+ fill=color, outline='white', width=2)
304
+
305
+ # Draw Label
306
+ label = f"{coord['type']}"
307
+ text_pos = (pixel_x + 10, pixel_y - 10)
308
+ if font:
309
+ draw.text(text_pos, label, fill=color, font=font)
310
+ else:
311
+ draw.text(text_pos, label, fill=color)
312
+
313
+ # Draw Arrow for Drag
314
+ if coord['type'] == 'drag_from' and i + 1 < len(coordinates) and coordinates[i + 1]['type'] == 'drag_to':
315
+ next_coord = coordinates[i + 1]
316
+ nx, ny = next_coord['x'], next_coord['y']
317
 
318
+ if nx <= 1.0 and ny <= 1.0:
319
+ end_x, end_y = int(nx * width), int(ny * height)
320
+ else:
321
+ end_x, end_y = int(nx), int(ny)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
+ draw.line([pixel_x, pixel_y, end_x, end_y], fill='orange', width=3)
324
+
325
+ return img_copy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  # -----------------------------------------------------------------------------
328
+ # 4. INITIALIZATION
329
  # -----------------------------------------------------------------------------
330
 
331
+ # Using Fara-7B (or fallback)
332
+ MODEL_ID = "microsoft/Fara-7B"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
+ print(f"Initializing {MODEL_ID}...")
335
+ # Global model instance
336
+ # Note: We initialize this lazily or globally depending on environment.
337
+ # For Gradio Spaces, global init is standard.
338
+ try:
339
+ model = TransformersModel(model_id=MODEL_ID, to_device="cuda" if torch.cuda.is_available() else "cpu")
340
+ except Exception as e:
341
+ print(f"Failed to load Fara. Trying fallback Qwen...")
342
+ model = TransformersModel(model_id="Qwen/Qwen2.5-VL-7B-Instruct", to_device="cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
343
 
344
  # -----------------------------------------------------------------------------
345
+ # 5. GRADIO APP
346
  # -----------------------------------------------------------------------------
347
 
348
+ @spaces.GPU
349
+ def navigate(input_numpy_image: np.ndarray, task: str) -> Tuple[str, Optional[Image.Image]]:
350
+ if input_numpy_image is None:
351
+ return "Please upload an image.", None
352
+
353
+ input_pil_image = array_to_image(input_numpy_image)
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
+ # Generate Prompt
356
+ prompt_msgs = get_navigation_prompt(task, input_pil_image)
 
 
 
 
 
357
 
358
+ # Generate Response
359
+ print("Generating response...")
360
+ response_str = model.generate(prompt_msgs, max_new_tokens=500)
361
+ print(f"Model Response: {response_str}")
362
 
363
+ # Parse
364
+ actions = parse_actions_from_response(response_str)
365
 
366
+ # Extract Coordinates
367
+ all_coordinates = []
368
+ for action_code in actions:
369
+ coords = extract_coordinates_from_action(action_code)
370
+ all_coordinates.extend(coords)
371
 
372
+ # Visualize
373
+ localized_image = input_pil_image
374
+ if all_coordinates:
375
+ localized_image = create_localized_image(input_pil_image, all_coordinates)
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
+ return response_str, localized_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
 
379
+ title = "Fara-7B GUI Operator 🤖"
380
+ description = """
381
+ ### Fara GUI Agent Demo
382
+ Upload a screenshot and give an instruction. The model will analyze the UI and output the Python code to execute the action.
383
+ This demo visualizes where the model wants to click or drag.
384
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
387
+ gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
388
+ gr.Markdown(description)
389
 
390
+ with gr.Row():
391
+ input_image = gr.Image(label="Upload Screenshot", height=500, type="numpy")
392
+
 
 
 
393
  with gr.Row():
394
  with gr.Column(scale=1):
395
+ task_input = gr.Textbox(
396
+ label="Instruction",
397
+ placeholder="e.g. Click on the Search button...",
398
+ lines=2
399
+ )
400
+ submit_btn = gr.Button("Generate Action", variant="primary")
401
 
402
+ with gr.Column(scale=1):
403
+ output_code = gr.Textbox(label="Generated Python Code", lines=10)
 
 
 
 
 
404
 
405
+ # Output image gets updated with markers
406
+ submit_btn.click(
407
+ fn=navigate,
408
+ inputs=[input_image, task_input],
409
+ outputs=[output_code, input_image]
 
 
410
  )
411
 
412
+ # Optional: Examples
413
+ # gr.Examples(...)
414
+
415
  if __name__ == "__main__":
416
+ demo.launch()