prithivMLmods committed on
Commit
17659d9
·
verified ·
1 Parent(s): 7e47e30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -203
app.py CHANGED
@@ -3,16 +3,15 @@ import re
3
  import time
4
  import shutil
5
  import uuid
 
6
  import tempfile
7
- import unicodedata
8
  from io import BytesIO
9
- from typing import Tuple, Optional, List, Dict, Any
10
 
11
  import gradio as gr
12
- import numpy as np
13
  import torch
14
  import spaces
15
- from PIL import Image, ImageDraw, ImageFont
16
 
17
  # Transformers imports
18
  from transformers import (
@@ -34,7 +33,10 @@ from webdriver_manager.chrome import ChromeDriverManager
34
  # CONSTANTS & CONFIG
35
  # -----------------------------------------------------------------------------
36
 
37
- MODEL_ID = "microsoft/Fara-7B" # Or your specific Fara model repo
 
 
 
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
  WIDTH = 1024
40
  HEIGHT = 768
@@ -42,35 +44,40 @@ TMP_DIR = "./tmp"
42
  if not os.path.exists(TMP_DIR):
43
  os.makedirs(TMP_DIR)
44
 
45
- # System Prompt adapted for Fara/GUI agents
46
- OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
47
- You need to generate the next action to complete the task.
48
-
49
- Supported actions:
50
- 1. `click(x=0.5, y=0.5)`: Click at the specific location.
51
- 2. `right_click(x=0.5, y=0.5)`: Right click at the specific location.
52
- 3. `double_click(x=0.5, y=0.5)`: Double click at the specific location.
53
- 4. `type_text(text="hello")`: Type the text.
54
- 5. `scroll(amount=2, direction="down")`: Scroll the page.
55
- 6. `press_key(key="enter")`: Press a specific key.
56
- 7. `open_url(url="https://google.com")`: Open a specific URL.
57
-
58
- Output format:
59
- Please wrap the action code in <code> </code> tags.
 
 
 
 
 
60
  Example:
61
- <code>
62
- click(x=0.23, y=0.45)
63
- </code>
64
  """
65
 
66
  # -----------------------------------------------------------------------------
67
- # MODEL WRAPPER (Replacing smolagents)
68
  # -----------------------------------------------------------------------------
69
 
70
- class FaraModelWrapper:
71
  def __init__(self, model_id: str, to_device: str = "cuda"):
72
- print(f"Loading {model_id} on {to_device}...")
73
- self.model_id = model_id
74
 
75
  try:
76
  self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
@@ -80,27 +87,25 @@ class FaraModelWrapper:
80
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
81
  device_map="auto" if to_device == "cuda" else None,
82
  )
83
- if to_device == "cpu":
84
- self.model.to("cpu")
85
- self.model.eval()
86
- print("Model loaded successfully.")
87
  except Exception as e:
88
- print(f"Failed to load Fara, falling back to Qwen2.5-VL-7B for demo compatibility. Error: {e}")
89
- fallback_id = "Qwen/Qwen2.5-VL-7B-Instruct"
90
- self.processor = AutoProcessor.from_pretrained(fallback_id, trust_remote_code=True)
91
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
92
- fallback_id,
93
  trust_remote_code=True,
94
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
95
- device_map="auto",
96
  )
 
 
 
 
 
97
 
98
  def generate(self, messages: list[dict], max_new_tokens=512):
99
- # Prepare inputs for Fara/QwenVL
100
  text = self.processor.apply_chat_template(
101
  messages, tokenize=False, add_generation_prompt=True
102
  )
103
-
104
  image_inputs, video_inputs = process_vision_info(messages)
105
 
106
  inputs = self.processor(
@@ -113,24 +118,19 @@ class FaraModelWrapper:
113
  inputs = inputs.to(self.model.device)
114
 
115
  with torch.no_grad():
116
- generated_ids = self.model.generate(
117
- **inputs,
118
- max_new_tokens=max_new_tokens
119
- )
120
 
121
- # Trim input tokens
122
  generated_ids_trimmed = [
123
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
124
  ]
125
-
126
  output_text = self.processor.batch_decode(
127
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
128
  )[0]
129
 
130
  return output_text
131
 
132
- # Initialize global model
133
- model = FaraModelWrapper(MODEL_ID, DEVICE)
134
 
135
  # -----------------------------------------------------------------------------
136
  # SELENIUM SANDBOX
@@ -168,6 +168,9 @@ class SeleniumSandbox:
168
 
169
  self.driver = webdriver.Chrome(service=service, options=chrome_opts)
170
  self.driver.set_window_size(width, height)
 
 
 
171
  print("Selenium started.")
172
  except Exception as e:
173
  print(f"Selenium init failed: {e}")
@@ -178,59 +181,66 @@ class SeleniumSandbox:
178
  return Image.open(BytesIO(self.driver.get_screenshot_as_png()))
179
 
180
  def execute_action(self, action_data: dict):
181
- """Execute parsed action on the browser"""
182
- action_type = action_data.get('type')
 
 
 
183
 
184
  try:
185
  actions = ActionChains(self.driver)
186
  body = self.driver.find_element(By.TAG_NAME, "body")
187
-
188
- # Helper to move to coordinates
189
- def move_to(x_norm, y_norm):
190
- # Convert normalized (0-1) to pixel coordinates
 
 
 
 
191
  x_px = int(x_norm * self.width)
192
  y_px = int(y_norm * self.height)
 
 
193
  actions.move_to_element_with_offset(body, 0, 0)
194
  actions.move_by_offset(x_px, y_px)
195
-
196
- if action_type in ['click', 'right_click', 'double_click']:
197
- move_to(action_data['x'], action_data['y'])
198
- if action_type == 'click': actions.click()
199
- elif action_type == 'right_click': actions.context_click()
200
- elif action_type == 'double_click': actions.double_click()
201
  actions.perform()
202
 
203
- elif action_type == 'type_text':
204
- text = action_data.get('text', '')
 
 
 
 
 
 
 
 
 
 
 
 
205
  actions.send_keys(text)
 
 
206
  actions.perform()
207
-
208
- elif action_type == 'press_key':
209
- key_name = action_data.get('key', '').lower()
210
- k = getattr(Keys, key_name.upper(), None)
211
- if not k:
212
- if key_name == "enter": k = Keys.ENTER
213
- elif key_name == "space": k = Keys.SPACE
214
- elif key_name == "backspace": k = Keys.BACK_SPACE
215
- if k:
216
- actions.send_keys(k)
217
- actions.perform()
218
-
219
- elif action_type == 'scroll':
220
- amount = action_data.get('amount', 2)
221
- direction = action_data.get('direction', 'down')
222
- scroll_y = amount * 100
223
- if direction == 'up': scroll_y = -scroll_y
224
- self.driver.execute_script(f"window.scrollBy(0, {scroll_y});")
225
-
226
- elif action_type == 'open_url':
227
- url = action_data.get('url', '')
228
- if not url.startswith('http'): url = 'https://' + url
229
- self.driver.get(url)
230
- time.sleep(2)
231
-
232
  return f"Executed {action_type}"
 
233
  except Exception as e:
 
234
  return f"Action failed: {e}"
235
 
236
  def cleanup(self):
@@ -239,124 +249,93 @@ class SeleniumSandbox:
239
  shutil.rmtree(self.tmp_dir, ignore_errors=True)
240
 
241
  # -----------------------------------------------------------------------------
242
- # PARSING LOGIC
243
  # -----------------------------------------------------------------------------
244
 
245
- def parse_code_block(response: str) -> str:
246
- pattern = r"<code>\s*(.*?)\s*</code>"
247
- matches = re.findall(pattern, response, re.DOTALL)
248
- if matches:
249
- return matches[-1].strip() # Return the last code block
250
- return ""
251
-
252
- def parse_action_string(action_str: str) -> dict:
253
- """Parse string like 'click(x=0.5, y=0.5)' into a dict"""
254
- # Simple regex parsing for demonstration
255
- action_data = {}
256
 
257
- # 1. Coordinate actions: name(x=..., y=...)
258
- coord_match = re.match(r"(\w+)\s*\(\s*x\s*=\s*([0-9.]+)\s*,\s*y\s*=\s*([0-9.]+)\s*\)", action_str)
259
- if coord_match:
260
- return {
261
- "type": coord_match.group(1),
262
- "x": float(coord_match.group(2)),
263
- "y": float(coord_match.group(3))
264
- }
265
-
266
- # 2. Open URL: open_url(url="...")
267
- url_match = re.match(r"open_url\s*\(\s*url\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
268
- if url_match:
269
- return {"type": "open_url", "url": url_match.group(1)}
270
-
271
- # 3. Type text: type_text(text="...")
272
- text_match = re.match(r"type_text\s*\(\s*text\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
273
- if text_match:
274
- return {"type": "type_text", "text": text_match.group(1)}
275
-
276
- # 4. Press key: press_key(key="...")
277
- key_match = re.match(r"press_key\s*\(\s*key\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
278
- if key_match:
279
- return {"type": "press_key", "key": key_match.group(1)}
280
-
281
- # 5. Scroll: scroll(amount=..., direction="...")
282
- if "scroll" in action_str:
283
- return {"type": "scroll", "amount": 2, "direction": "down"} # Default
284
-
285
- return {}
286
 
287
  # -----------------------------------------------------------------------------
288
- # MAIN LOOP
289
  # -----------------------------------------------------------------------------
290
 
 
 
 
291
  @spaces.GPU(duration=120)
292
  def agent_step(task_instruction: str, history: list, sandbox_state: dict):
293
- # Initialize sandbox if needed (handled via state in Gradio mostly, but for safety)
294
  if 'uuid' not in sandbox_state:
295
  sandbox_state['uuid'] = str(uuid.uuid4())
296
- sandbox = SeleniumSandbox(WIDTH, HEIGHT)
297
- # Store sandbox instance reference globally or handle cleanup carefully
298
- # For this demo, we'll recreate/attach to session based on state if persisting,
299
- # but here we'll assume a persistent session for the run.
300
 
301
- # HACK: For Gradio state persistence with objects that can't be pickled easily,
302
- # we often use a global dict mapping UUID -> Sandbox
303
- sandbox_id = sandbox_state['uuid']
304
- if sandbox_id not in SANDBOX_REGISTRY:
305
- SANDBOX_REGISTRY[sandbox_id] = SeleniumSandbox(WIDTH, HEIGHT)
306
 
307
- sandbox = SANDBOX_REGISTRY[sandbox_id]
308
 
309
- # 1. Get Screenshot
310
  screenshot = sandbox.get_screenshot()
311
 
312
- # 2. Construct Prompt
313
- # Convert history text to string context if needed
 
314
 
315
  messages = [
 
316
  {
317
- "role": "system",
318
- "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]
319
- },
320
- {
321
- "role": "user",
322
  "content": [
323
  {"type": "image", "image": screenshot},
324
- {"type": "text", "text": f"Instruction: {task_instruction}\nPrevious Actions: {history[-1] if history else 'None'}"}
325
  ]
326
  }
327
  ]
328
-
329
- # 3. Model Inference
330
  response = model.generate(messages)
331
 
332
- # 4. Parse Action
333
- action_code = parse_code_block(response)
334
- action_data = parse_action_string(action_code)
335
 
336
- log_entry = f"Step: {len(history)+1}\nModel Thought: {response}\nAction: {action_code}"
337
 
338
- # 5. Execute Action
339
- execution_result = "No valid action found"
340
  if action_data:
341
- execution_result = sandbox.execute_action(action_data)
 
342
 
343
- # Draw marker if coordinate action
344
- if 'x' in action_data:
 
345
  draw = ImageDraw.Draw(screenshot)
346
- x_px = action_data['x'] * WIDTH
347
- y_px = action_data['y'] * HEIGHT
348
- r = 10
349
- draw.ellipse((x_px-r, y_px-r, x_px+r, y_px+r), outline="red", width=3)
 
 
 
350
 
351
- log_entry += f"\nResult: {execution_result}"
352
  history.append(log_entry)
353
 
354
- # Return updated screenshot and history
355
  return screenshot, history, sandbox_state
356
 
357
- # Global registry for sandboxes
358
- SANDBOX_REGISTRY = {}
359
-
360
  def cleanup_sandbox(sandbox_state):
361
  sid = sandbox_state.get('uuid')
362
  if sid and sid in SANDBOX_REGISTRY:
@@ -368,70 +347,75 @@ def cleanup_sandbox(sandbox_state):
368
  # GRADIO UI
369
  # -----------------------------------------------------------------------------
370
 
371
- def run_task_loop(task, history, state):
372
- # This generator function runs the agent loop
373
- max_steps = 10
374
-
375
- for i in range(max_steps):
376
  try:
377
- # Run one step
378
- screenshot, new_history, new_state = agent_step(task, history, state)
379
- history = new_history
380
 
381
- # Yield updates to UI
382
- # We yield the logs (joined) and the latest image
383
- logs_text = "\n\n" + "-"*40 + "\n\n".join(history)
384
- yield screenshot, logs_text, state
385
-
386
- # Check for termination (simplistic)
387
- if "Done" in history[-1] or "finished" in history[-1].lower():
388
- break
389
-
390
- time.sleep(1) # Pause for visual effect
391
 
 
 
392
  except Exception as e:
393
- error_msg = f"Error in loop: {e}"
394
- history.append(error_msg)
395
  yield None, "\n".join(history), state
396
  break
397
 
398
- # UI Layout
399
  custom_css = """
400
- #view_img { height: 600px; object-fit: contain; }
401
  """
402
 
403
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
404
  state = gr.State({})
405
  history = gr.State([])
406
 
407
- gr.Markdown("# 🤖 Fara CUA - Chrome Agent")
408
-
 
409
  with gr.Row():
410
  with gr.Column(scale=1):
411
- task_input = gr.Textbox(label="Task Instruction", value="Go to google.com and search for 'SpaceX'")
412
- run_btn = gr.Button("Run Agent", variant="primary")
413
- clear_btn = gr.Button("Reset / Clear")
 
 
 
 
 
414
 
 
 
 
 
 
 
415
  with gr.Column(scale=2):
416
- browser_view = gr.Image(label="Live Browser View", elem_id="view_img", interactive=False)
417
-
418
- logs_output = gr.Textbox(label="Agent Logs", lines=15, interactive=False)
 
 
 
 
 
419
 
420
- # Event handlers
421
  run_btn.click(
422
- fn=run_task_loop,
423
  inputs=[task_input, history, state],
424
- outputs=[browser_view, logs_output, state]
425
  )
426
 
427
- clear_btn.click(
428
  fn=cleanup_sandbox,
429
  inputs=[state],
430
  outputs=[history, state]
431
  ).then(
432
  lambda: (None, ""),
433
- outputs=[browser_view, logs_output]
434
  )
435
 
436
  if __name__ == "__main__":
437
- demo.launch()
 
3
  import time
4
  import shutil
5
  import uuid
6
+ import json
7
  import tempfile
 
8
  from io import BytesIO
9
+ import threading
10
 
11
  import gradio as gr
 
12
  import torch
13
  import spaces
14
+ from PIL import Image, ImageDraw
15
 
16
  # Transformers imports
17
  from transformers import (
 
33
  # CONSTANTS & CONFIG
34
  # -----------------------------------------------------------------------------
35
 
36
+ MODEL_ID = "microsoft/Fara-7B"
37
+ # Use the Qwen fallback if Fara isn't directly accessible in your environment
38
+ FALLBACK_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
39
+
40
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
  WIDTH = 1024
42
  HEIGHT = 768
 
44
  if not os.path.exists(TMP_DIR):
45
  os.makedirs(TMP_DIR)
46
 
47
+ # Updated System Prompt to match the JSON tool_call format the model prefers
48
+ OS_SYSTEM_PROMPT = """You are a helpful GUI agent controlling a Chrome browser.
49
+ You will be given a screenshot of the current page and a high-level task.
50
+ You need to generate the next action to move towards completing the task.
51
+
52
+ The browser resolution is 1024x768.
53
+
54
+ Output your action in the following XML format containing JSON:
55
+ <tool_call>
56
+ {"name": "Browser", "arguments": { ... }}
57
+ </tool_call>
58
+
59
+ Supported Actions (in 'arguments'):
60
+ 1. Click: {"action": "click", "coordinate": [x, y]}
61
+ (where x and y are integer coordinates based on a 1000x1000 normalized grid)
62
+ 2. Type: {"action": "type_text", "text": "something", "coordinate": [x, y], "press_enter": true}
63
+ (Coordinate is optional but recommended to focus the input field first)
64
+ 3. Scroll: {"action": "scroll", "direction": "down"}
65
+ 4. Navigate: {"action": "navigate", "url": "https://..."}
66
+
67
  Example:
68
+ <tool_call>
69
+ {"name": "Browser", "arguments": {"action": "type_text", "coordinate": [500, 280], "text": "hugging face models", "press_enter": true}}
70
+ </tool_call>
71
  """
72
 
73
  # -----------------------------------------------------------------------------
74
+ # MODEL WRAPPER
75
  # -----------------------------------------------------------------------------
76
 
77
+ class ModelWrapper:
78
  def __init__(self, model_id: str, to_device: str = "cuda"):
79
+ print(f"Loading model: {model_id} on {to_device}...")
80
+ self.device = to_device
81
 
82
  try:
83
  self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
87
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
88
  device_map="auto" if to_device == "cuda" else None,
89
  )
 
 
 
 
90
  except Exception as e:
91
+ print(f"Primary model load failed ({e}). Loading fallback: {FALLBACK_MODEL_ID}")
92
+ self.processor = AutoProcessor.from_pretrained(FALLBACK_MODEL_ID, trust_remote_code=True)
 
93
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
94
+ FALLBACK_MODEL_ID,
95
  trust_remote_code=True,
96
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
97
+ device_map="auto" if to_device == "cuda" else None,
98
  )
99
+
100
+ if to_device == "cpu":
101
+ self.model.to("cpu")
102
+ self.model.eval()
103
+ print("Model loaded successfully.")
104
 
105
  def generate(self, messages: list[dict], max_new_tokens=512):
 
106
  text = self.processor.apply_chat_template(
107
  messages, tokenize=False, add_generation_prompt=True
108
  )
 
109
  image_inputs, video_inputs = process_vision_info(messages)
110
 
111
  inputs = self.processor(
 
118
  inputs = inputs.to(self.model.device)
119
 
120
  with torch.no_grad():
121
+ generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
 
 
122
 
 
123
  generated_ids_trimmed = [
124
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
125
  ]
 
126
  output_text = self.processor.batch_decode(
127
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
128
  )[0]
129
 
130
  return output_text
131
 
132
+ # Initialize Global Model
133
+ model = ModelWrapper(MODEL_ID, DEVICE)
134
 
135
  # -----------------------------------------------------------------------------
136
  # SELENIUM SANDBOX
 
168
 
169
  self.driver = webdriver.Chrome(service=service, options=chrome_opts)
170
  self.driver.set_window_size(width, height)
171
+
172
+ # Start blank
173
+ self.driver.get("about:blank")
174
  print("Selenium started.")
175
  except Exception as e:
176
  print(f"Selenium init failed: {e}")
 
181
  return Image.open(BytesIO(self.driver.get_screenshot_as_png()))
182
 
183
  def execute_action(self, action_data: dict):
184
+ """Execute parsed JSON action on the browser"""
185
+ # Mapping model's JSON structure to Selenium calls
186
+
187
+ args = action_data.get("arguments", {})
188
+ action_type = args.get("action")
189
 
190
  try:
191
  actions = ActionChains(self.driver)
192
  body = self.driver.find_element(By.TAG_NAME, "body")
193
+
194
+ # 1. Handle Coordinate Movement (Common to click/type)
195
+ if "coordinate" in args:
196
+ coords = args["coordinate"]
197
+ # Assuming Fara uses 1000x1000 normalization standard
198
+ x_norm = coords[0] / 1000
199
+ y_norm = coords[1] / 1000
200
+
201
  x_px = int(x_norm * self.width)
202
  y_px = int(y_norm * self.height)
203
+
204
+ # Move mouse
205
  actions.move_to_element_with_offset(body, 0, 0)
206
  actions.move_by_offset(x_px, y_px)
207
+ actions.click() # Focus the element
 
 
 
 
 
208
  actions.perform()
209
 
210
+ # Reset actions queue
211
+ actions = ActionChains(self.driver)
212
+
213
+ # 2. Handle Specific Actions
214
+ if action_type == "navigate":
215
+ url = args.get("url")
216
+ if url:
217
+ if not url.startswith("http"): url = "https://" + url
218
+ self.driver.get(url)
219
+ time.sleep(2)
220
+ return f"Navigated to {url}"
221
+
222
+ elif action_type == "type_text":
223
+ text = args.get("text", "")
224
  actions.send_keys(text)
225
+ if args.get("press_enter", False):
226
+ actions.send_keys(Keys.ENTER)
227
  actions.perform()
228
+ return f"Typed '{text}'"
229
+
230
+ elif action_type == "click":
231
+ # Click is handled in coordinate block above, just return status
232
+ return f"Clicked at {args.get('coordinate')}"
233
+
234
+ elif action_type == "scroll":
235
+ direction = args.get("direction", "down")
236
+ scroll_amount = 300 if direction == "down" else -300
237
+ self.driver.execute_script(f"window.scrollBy(0, {scroll_amount});")
238
+ return f"Scrolled {direction}"
239
+
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  return f"Executed {action_type}"
241
+
242
  except Exception as e:
243
+ print(f"Execution Error: {e}")
244
  return f"Action failed: {e}"
245
 
246
  def cleanup(self):
 
249
  shutil.rmtree(self.tmp_dir, ignore_errors=True)
250
 
251
  # -----------------------------------------------------------------------------
252
+ # PARSER
253
  # -----------------------------------------------------------------------------
254
 
255
def parse_model_response(response: str) -> "dict | None":
    """Extract the first ``<tool_call>`` JSON object from a model response.

    The JSON object is decoded directly from the text that follows the
    opening tag. Decoding therefore still works when the closing
    ``</tool_call>`` tag is missing (for example, truncated generation) or
    when a string inside the JSON contains braces or tag-like text.

    Args:
        response: Raw text produced by the model.

    Returns:
        The decoded tool-call dictionary, or ``None`` when no tool call is
        present or its JSON cannot be decoded.
    """
    if not response:
        return None

    open_tag = "<tool_call>"
    decoder = json.JSONDecoder()

    tag_pos = response.find(open_tag)
    while tag_pos != -1:
        start = tag_pos + len(open_tag)
        # raw_decode does not skip leading whitespace, so do it here.
        while start < len(response) and response[start].isspace():
            start += 1

        if response.startswith("{", start):
            try:
                data, _ = decoder.raw_decode(response, start)
            except json.JSONDecodeError:
                print("Failed to decode JSON from tool_call")
                return None
            return data

        # This tag is not followed by a JSON object; try the next one.
        tag_pos = response.find(open_tag, start)

    return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  # -----------------------------------------------------------------------------
275
+ # AGENT LOOP
276
  # -----------------------------------------------------------------------------
277
 
278
+ # Global registry to persist sessions in Gradio
279
+ SANDBOX_REGISTRY = {}
280
+
281
  @spaces.GPU(duration=120)
282
  def agent_step(task_instruction: str, history: list, sandbox_state: dict):
283
+ # Retrieve or create sandbox
284
  if 'uuid' not in sandbox_state:
285
  sandbox_state['uuid'] = str(uuid.uuid4())
 
 
 
 
286
 
287
+ sid = sandbox_state['uuid']
288
+ if sid not in SANDBOX_REGISTRY:
289
+ SANDBOX_REGISTRY[sid] = SeleniumSandbox(WIDTH, HEIGHT)
 
 
290
 
291
+ sandbox = SANDBOX_REGISTRY[sid]
292
 
293
+ # 1. Capture State
294
  screenshot = sandbox.get_screenshot()
295
 
296
+ # 2. Build Messages
297
+ # Fara works best when seeing the history of images, but for memory efficiency
298
+ # in this demo we will just send the current screenshot + text history.
299
 
300
  messages = [
301
+ {"role": "system", "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]},
302
  {
303
+ "role": "user",
 
 
 
 
304
  "content": [
305
  {"type": "image", "image": screenshot},
306
+ {"type": "text", "text": f"Task: {task_instruction}\nPrevious Actions Log:\n" + "\n".join(history[-3:])}
307
  ]
308
  }
309
  ]
310
+
311
+ # 3. Inference
312
  response = model.generate(messages)
313
 
314
+ # 4. Parse & Execute
315
+ action_data = parse_model_response(response)
 
316
 
317
+ log_entry = f"Thought: {response}\n"
318
 
 
 
319
  if action_data:
320
+ result = sandbox.execute_action(action_data)
321
+ log_entry += f"Action: {action_data.get('arguments', {}).get('action')}\nResult: {result}"
322
 
323
+ # Visualize click on screenshot for UI
324
+ args = action_data.get("arguments", {})
325
+ if "coordinate" in args:
326
  draw = ImageDraw.Draw(screenshot)
327
+ coords = args["coordinate"]
328
+ # Map 1000x1000 back to image size
329
+ x = int(coords[0] / 1000 * WIDTH)
330
+ y = int(coords[1] / 1000 * HEIGHT)
331
+ draw.ellipse((x-10, y-10, x+10, y+10), outline="red", width=5)
332
+ else:
333
+ log_entry += "Action: Parsing Failed or No Action"
334
 
 
335
  history.append(log_entry)
336
 
 
337
  return screenshot, history, sandbox_state
338
 
 
 
 
339
  def cleanup_sandbox(sandbox_state):
340
  sid = sandbox_state.get('uuid')
341
  if sid and sid in SANDBOX_REGISTRY:
 
347
  # GRADIO UI
348
  # -----------------------------------------------------------------------------
349
 
350
def run_loop(task, history, state):
    """Drive the agent for up to ``MAX_STEPS`` steps, streaming UI updates.

    This is a generator. After every step it yields
    ``(image, log_text, state)`` for the browser view, the log textbox and
    the session state.

    Args:
        task: Task text entered by the user.
        history: List of log strings accumulated so far.
        state: Per-session state dict passed through to ``agent_step``.

    Yields:
        ``(image, log_text, state)``. ``image`` is ``None`` when no step ran
        (blank task) or when a step raised; in the latter case the error is
        appended to the log and the loop stops.
    """
    MAX_STEPS = 10
    # Divider placed *between* log entries.
    separator = "\n" + "=" * 40 + "\n"

    # Do not spend model steps on an empty instruction.
    if not task or not task.strip():
        yield None, "No task provided.", state
        return

    for step in range(MAX_STEPS):
        try:
            # Keep the state returned by the step, not just the one passed
            # in, so changes made inside agent_step are never lost.
            img, history, state = agent_step(task, history, state)

            # Combine history into a readable log.
            yield img, separator.join(history), state

            if step < MAX_STEPS - 1:
                time.sleep(1)  # visual pause between steps
        except Exception as e:
            # Top-level boundary of the loop: surface the error in the log
            # and stop instead of crashing the UI callback.
            history.append(f"Critical Error: {e}")
            yield None, separator.join(history), state
            break
366
 
 
367
  custom_css = """
368
+ .browser-img { height: 600px; object-fit: contain; border: 2px solid #333; }
369
  """
370
 
371
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
372
  state = gr.State({})
373
  history = gr.State([])
374
 
375
+ gr.Markdown("# 🌐 Fara CUA - Chrome Agent")
376
+ gr.Markdown("Agent that uses **Microsoft Fara-7B** (Vision) to control a headless Chrome browser.")
377
+
378
  with gr.Row():
379
  with gr.Column(scale=1):
380
+ task_input = gr.Textbox(
381
+ label="Task",
382
+ value="Go to google.com and search for 'Hugging Face models'",
383
+ lines=2
384
+ )
385
+ with gr.Row():
386
+ run_btn = gr.Button("▶ Run Agent", variant="primary")
387
+ reset_btn = gr.Button("⏹ Reset")
388
 
389
+ gr.Examples([
390
+ "Go to google.com and search for 'Hugging Face models'",
391
+ "Navigate to wikipedia.org, type 'Artificial Intelligence' and press enter",
392
+ "Go to bing.com and search for 'SpaceX launch'"
393
+ ], inputs=task_input)
394
+
395
  with gr.Column(scale=2):
396
+ browser_view = gr.Image(
397
+ label="Live Browser View",
398
+ interactive=False,
399
+ elem_classes="browser-img",
400
+ type="pil"
401
+ )
402
+
403
+ logs_out = gr.Textbox(label="Execution Logs", lines=10, interactive=False)
404
 
 
405
  run_btn.click(
406
+ fn=run_loop,
407
  inputs=[task_input, history, state],
408
+ outputs=[browser_view, logs_out, state]
409
  )
410
 
411
+ reset_btn.click(
412
  fn=cleanup_sandbox,
413
  inputs=[state],
414
  outputs=[history, state]
415
  ).then(
416
  lambda: (None, ""),
417
+ outputs=[browser_view, logs_out]
418
  )
419
 
420
  if __name__ == "__main__":
421
+ demo.launch(share=True)