prithivMLmods commited on
Commit
3fa36ec
·
verified ·
1 Parent(s): 17659d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -187
app.py CHANGED
@@ -3,15 +3,16 @@ import re
3
  import time
4
  import shutil
5
  import uuid
6
- import json
7
  import tempfile
 
8
  from io import BytesIO
9
- import threading
10
 
11
  import gradio as gr
 
12
  import torch
13
  import spaces
14
- from PIL import Image, ImageDraw
15
 
16
  # Transformers imports
17
  from transformers import (
@@ -33,10 +34,7 @@ from webdriver_manager.chrome import ChromeDriverManager
33
  # CONSTANTS & CONFIG
34
  # -----------------------------------------------------------------------------
35
 
36
- MODEL_ID = "microsoft/Fara-7B"
37
- # Use the Qwen fallback if Fara isn't directly accessible in your environment
38
- FALLBACK_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
39
-
40
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
  WIDTH = 1024
42
  HEIGHT = 768
@@ -44,40 +42,35 @@ TMP_DIR = "./tmp"
44
  if not os.path.exists(TMP_DIR):
45
  os.makedirs(TMP_DIR)
46
 
47
- # Updated System Prompt to match the JSON tool_call format the model prefers
48
- OS_SYSTEM_PROMPT = """You are a helpful GUI agent controlling a Chrome browser.
49
- You will be given a screenshot of the current page and a high-level task.
50
- You need to generate the next action to move towards completing the task.
51
-
52
- The browser resolution is 1024x768.
53
-
54
- Output your action in the following XML format containing JSON:
55
- <tool_call>
56
- {"name": "Browser", "arguments": { ... }}
57
- </tool_call>
58
-
59
- Supported Actions (in 'arguments'):
60
- 1. Click: {"action": "click", "coordinate": [x, y]}
61
- (where x and y are integer coordinates based on a 1000x1000 normalized grid)
62
- 2. Type: {"action": "type_text", "text": "something", "coordinate": [x, y], "press_enter": true}
63
- (Coordinate is optional but recommended to focus the input field first)
64
- 3. Scroll: {"action": "scroll", "direction": "down"}
65
- 4. Navigate: {"action": "navigate", "url": "https://..."}
66
-
67
  Example:
68
- <tool_call>
69
- {"name": "Browser", "arguments": {"action": "type_text", "coordinate": [500, 280], "text": "hugging face models", "press_enter": true}}
70
- </tool_call>
71
  """
72
 
73
  # -----------------------------------------------------------------------------
74
- # MODEL WRAPPER
75
  # -----------------------------------------------------------------------------
76
 
77
- class ModelWrapper:
78
  def __init__(self, model_id: str, to_device: str = "cuda"):
79
- print(f"Loading model: {model_id} on {to_device}...")
80
- self.device = to_device
81
 
82
  try:
83
  self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
@@ -87,25 +80,27 @@ class ModelWrapper:
87
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
88
  device_map="auto" if to_device == "cuda" else None,
89
  )
 
 
 
 
90
  except Exception as e:
91
- print(f"Primary model load failed ({e}). Loading fallback: {FALLBACK_MODEL_ID}")
92
- self.processor = AutoProcessor.from_pretrained(FALLBACK_MODEL_ID, trust_remote_code=True)
 
93
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
94
- FALLBACK_MODEL_ID,
95
  trust_remote_code=True,
96
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
97
- device_map="auto" if to_device == "cuda" else None,
98
  )
99
-
100
- if to_device == "cpu":
101
- self.model.to("cpu")
102
- self.model.eval()
103
- print("Model loaded successfully.")
104
 
105
  def generate(self, messages: list[dict], max_new_tokens=512):
 
106
  text = self.processor.apply_chat_template(
107
  messages, tokenize=False, add_generation_prompt=True
108
  )
 
109
  image_inputs, video_inputs = process_vision_info(messages)
110
 
111
  inputs = self.processor(
@@ -118,19 +113,24 @@ class ModelWrapper:
118
  inputs = inputs.to(self.model.device)
119
 
120
  with torch.no_grad():
121
- generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
 
 
122
 
 
123
  generated_ids_trimmed = [
124
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
125
  ]
 
126
  output_text = self.processor.batch_decode(
127
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
128
  )[0]
129
 
130
  return output_text
131
 
132
- # Initialize Global Model
133
- model = ModelWrapper(MODEL_ID, DEVICE)
134
 
135
  # -----------------------------------------------------------------------------
136
  # SELENIUM SANDBOX
@@ -168,9 +168,6 @@ class SeleniumSandbox:
168
 
169
  self.driver = webdriver.Chrome(service=service, options=chrome_opts)
170
  self.driver.set_window_size(width, height)
171
-
172
- # Start blank
173
- self.driver.get("about:blank")
174
  print("Selenium started.")
175
  except Exception as e:
176
  print(f"Selenium init failed: {e}")
@@ -181,66 +178,59 @@ class SeleniumSandbox:
181
  return Image.open(BytesIO(self.driver.get_screenshot_as_png()))
182
 
183
  def execute_action(self, action_data: dict):
184
- """Execute parsed JSON action on the browser"""
185
- # Mapping model's JSON structure to Selenium calls
186
-
187
- args = action_data.get("arguments", {})
188
- action_type = args.get("action")
189
 
190
  try:
191
  actions = ActionChains(self.driver)
192
  body = self.driver.find_element(By.TAG_NAME, "body")
193
-
194
- # 1. Handle Coordinate Movement (Common to click/type)
195
- if "coordinate" in args:
196
- coords = args["coordinate"]
197
- # Assuming Fara uses 1000x1000 normalization standard
198
- x_norm = coords[0] / 1000
199
- y_norm = coords[1] / 1000
200
-
201
  x_px = int(x_norm * self.width)
202
  y_px = int(y_norm * self.height)
203
-
204
- # Move mouse
205
  actions.move_to_element_with_offset(body, 0, 0)
206
  actions.move_by_offset(x_px, y_px)
207
- actions.click() # Focus the element
 
 
 
 
 
208
  actions.perform()
209
 
210
- # Reset actions queue
211
- actions = ActionChains(self.driver)
212
-
213
- # 2. Handle Specific Actions
214
- if action_type == "navigate":
215
- url = args.get("url")
216
- if url:
217
- if not url.startswith("http"): url = "https://" + url
218
- self.driver.get(url)
219
- time.sleep(2)
220
- return f"Navigated to {url}"
221
-
222
- elif action_type == "type_text":
223
- text = args.get("text", "")
224
  actions.send_keys(text)
225
- if args.get("press_enter", False):
226
- actions.send_keys(Keys.ENTER)
227
  actions.perform()
228
- return f"Typed '{text}'"
229
-
230
- elif action_type == "click":
231
- # Click is handled in coordinate block above, just return status
232
- return f"Clicked at {args.get('coordinate')}"
233
-
234
- elif action_type == "scroll":
235
- direction = args.get("direction", "down")
236
- scroll_amount = 300 if direction == "down" else -300
237
- self.driver.execute_script(f"window.scrollBy(0, {scroll_amount});")
238
- return f"Scrolled {direction}"
239
-
240
- return f"Executed {action_type}"
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  except Exception as e:
243
- print(f"Execution Error: {e}")
244
  return f"Action failed: {e}"
245
 
246
  def cleanup(self):
@@ -249,93 +239,124 @@ class SeleniumSandbox:
249
  shutil.rmtree(self.tmp_dir, ignore_errors=True)
250
 
251
  # -----------------------------------------------------------------------------
252
- # PARSER
253
  # -----------------------------------------------------------------------------
254
 
255
- def parse_model_response(response: str) -> dict:
256
- """
257
- Parses <tool_call> JSON content </tool_call>
258
- Returns a dictionary or None
259
- """
260
- # Regex to extract JSON inside tool_call tags
261
- pattern = r"<tool_call>\s*({.*?})\s*</tool_call>"
262
- match = re.search(pattern, response, re.DOTALL)
 
 
 
263
 
264
- if match:
265
- try:
266
- json_str = match.group(1)
267
- data = json.loads(json_str)
268
- return data
269
- except json.JSONDecodeError:
270
- print("Failed to decode JSON from tool_call")
271
- return None
272
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  # -----------------------------------------------------------------------------
275
- # AGENT LOOP
276
  # -----------------------------------------------------------------------------
277
 
278
- # Global registry to persist sessions in Gradio
279
- SANDBOX_REGISTRY = {}
280
-
281
  @spaces.GPU(duration=120)
282
  def agent_step(task_instruction: str, history: list, sandbox_state: dict):
283
- # Retrieve or create sandbox
284
  if 'uuid' not in sandbox_state:
285
  sandbox_state['uuid'] = str(uuid.uuid4())
 
 
 
 
286
 
287
- sid = sandbox_state['uuid']
288
- if sid not in SANDBOX_REGISTRY:
289
- SANDBOX_REGISTRY[sid] = SeleniumSandbox(WIDTH, HEIGHT)
 
 
290
 
291
- sandbox = SANDBOX_REGISTRY[sid]
292
 
293
- # 1. Capture State
294
  screenshot = sandbox.get_screenshot()
295
 
296
- # 2. Build Messages
297
- # Fara works best when seeing the history of images, but for memory efficiency
298
- # in this demo we will just send the current screenshot + text history.
299
 
300
  messages = [
301
- {"role": "system", "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]},
302
  {
303
- "role": "user",
 
 
 
 
304
  "content": [
305
  {"type": "image", "image": screenshot},
306
- {"type": "text", "text": f"Task: {task_instruction}\nPrevious Actions Log:\n" + "\n".join(history[-3:])}
307
  ]
308
  }
309
  ]
310
-
311
- # 3. Inference
312
  response = model.generate(messages)
313
 
314
- # 4. Parse & Execute
315
- action_data = parse_model_response(response)
 
316
 
317
- log_entry = f"Thought: {response}\n"
318
 
 
 
319
  if action_data:
320
- result = sandbox.execute_action(action_data)
321
- log_entry += f"Action: {action_data.get('arguments', {}).get('action')}\nResult: {result}"
322
 
323
- # Visualize click on screenshot for UI
324
- args = action_data.get("arguments", {})
325
- if "coordinate" in args:
326
  draw = ImageDraw.Draw(screenshot)
327
- coords = args["coordinate"]
328
- # Map 1000x1000 back to image size
329
- x = int(coords[0] / 1000 * WIDTH)
330
- y = int(coords[1] / 1000 * HEIGHT)
331
- draw.ellipse((x-10, y-10, x+10, y+10), outline="red", width=5)
332
- else:
333
- log_entry += "Action: Parsing Failed or No Action"
334
 
 
335
  history.append(log_entry)
336
 
 
337
  return screenshot, history, sandbox_state
338
 
 
 
 
339
  def cleanup_sandbox(sandbox_state):
340
  sid = sandbox_state.get('uuid')
341
  if sid and sid in SANDBOX_REGISTRY:
@@ -347,75 +368,70 @@ def cleanup_sandbox(sandbox_state):
347
  # GRADIO UI
348
  # -----------------------------------------------------------------------------
349
 
350
- def run_loop(task, history, state):
351
- MAX_STEPS = 10
352
- for i in range(MAX_STEPS):
 
 
353
  try:
354
- img, new_hist, new_state = agent_step(task, history, state)
355
- history = new_hist
 
356
 
357
- # Combine history into a readable log
358
- log_text = "\n" + "="*40 + "\n".join(history)
 
 
 
 
 
 
 
 
359
 
360
- yield img, log_text, state
361
- time.sleep(1) # Visual pause
362
  except Exception as e:
363
- history.append(f"Critical Error: {e}")
 
364
  yield None, "\n".join(history), state
365
  break
366
 
 
367
  custom_css = """
368
- .browser-img { height: 600px; object-fit: contain; border: 2px solid #333; }
369
  """
370
 
371
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
372
  state = gr.State({})
373
  history = gr.State([])
374
 
375
- gr.Markdown("# 🌐 Fara CUA - Chrome Agent")
376
- gr.Markdown("Agent that uses **Microsoft Fara-7B** (Vision) to control a headless Chrome browser.")
377
-
378
  with gr.Row():
379
  with gr.Column(scale=1):
380
- task_input = gr.Textbox(
381
- label="Task",
382
- value="Go to google.com and search for 'Hugging Face models'",
383
- lines=2
384
- )
385
- with gr.Row():
386
- run_btn = gr.Button("▶ Run Agent", variant="primary")
387
- reset_btn = gr.Button("⏹ Reset")
388
 
389
- gr.Examples([
390
- "Go to google.com and search for 'Hugging Face models'",
391
- "Navigate to wikipedia.org, type 'Artificial Intelligence' and press enter",
392
- "Go to bing.com and search for 'SpaceX launch'"
393
- ], inputs=task_input)
394
-
395
  with gr.Column(scale=2):
396
- browser_view = gr.Image(
397
- label="Live Browser View",
398
- interactive=False,
399
- elem_classes="browser-img",
400
- type="pil"
401
- )
402
-
403
- logs_out = gr.Textbox(label="Execution Logs", lines=10, interactive=False)
404
 
 
405
  run_btn.click(
406
- fn=run_loop,
407
  inputs=[task_input, history, state],
408
- outputs=[browser_view, logs_out, state]
409
  )
410
 
411
- reset_btn.click(
412
  fn=cleanup_sandbox,
413
  inputs=[state],
414
  outputs=[history, state]
415
  ).then(
416
  lambda: (None, ""),
417
- outputs=[browser_view, logs_out]
418
  )
419
 
420
  if __name__ == "__main__":
421
- demo.launch(share=True)
 
3
  import time
4
  import shutil
5
  import uuid
 
6
  import tempfile
7
+ import unicodedata
8
  from io import BytesIO
9
+ from typing import Tuple, Optional, List, Dict, Any
10
 
11
  import gradio as gr
12
+ import numpy as np
13
  import torch
14
  import spaces
15
+ from PIL import Image, ImageDraw, ImageFont
16
 
17
  # Transformers imports
18
  from transformers import (
 
34
  # CONSTANTS & CONFIG
35
  # -----------------------------------------------------------------------------
36
 
37
+ MODEL_ID = "microsoft/Fara-7B" # Or your specific Fara model repo
 
 
 
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
  WIDTH = 1024
40
  HEIGHT = 768
 
42
  if not os.path.exists(TMP_DIR):
43
  os.makedirs(TMP_DIR)
44
 
45
+ # System Prompt adapted for Fara/GUI agents
46
+ OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
47
+ You need to generate the next action to complete the task.
48
+
49
+ Supported actions:
50
+ 1. `click(x=0.5, y=0.5)`: Click at the specific location.
51
+ 2. `right_click(x=0.5, y=0.5)`: Right click at the specific location.
52
+ 3. `double_click(x=0.5, y=0.5)`: Double click at the specific location.
53
+ 4. `type_text(text="hello")`: Type the text.
54
+ 5. `scroll(amount=2, direction="down")`: Scroll the page.
55
+ 6. `press_key(key="enter")`: Press a specific key.
56
+ 7. `open_url(url="https://google.com")`: Open a specific URL.
57
+
58
+ Output format:
59
+ Please wrap the action code in <code> </code> tags.
 
 
 
 
 
60
  Example:
61
+ <code>
62
+ click(x=0.23, y=0.45)
63
+ </code>
64
  """
65
 
66
  # -----------------------------------------------------------------------------
67
+ # MODEL WRAPPER (Replacing smolagents)
68
  # -----------------------------------------------------------------------------
69
 
70
+ class FaraModelWrapper:
71
  def __init__(self, model_id: str, to_device: str = "cuda"):
72
+ print(f"Loading {model_id} on {to_device}...")
73
+ self.model_id = model_id
74
 
75
  try:
76
  self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
 
80
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
81
  device_map="auto" if to_device == "cuda" else None,
82
  )
83
+ if to_device == "cpu":
84
+ self.model.to("cpu")
85
+ self.model.eval()
86
+ print("Model loaded successfully.")
87
  except Exception as e:
88
+ print(f"Failed to load Fara, falling back to Qwen2.5-VL-7B for demo compatibility. Error: {e}")
89
+ fallback_id = "Qwen/Qwen2.5-VL-7B-Instruct"
90
+ self.processor = AutoProcessor.from_pretrained(fallback_id, trust_remote_code=True)
91
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
92
+ fallback_id,
93
  trust_remote_code=True,
94
  torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
95
+ device_map="auto",
96
  )
 
 
 
 
 
97
 
98
  def generate(self, messages: list[dict], max_new_tokens=512):
99
+ # Prepare inputs for Fara/QwenVL
100
  text = self.processor.apply_chat_template(
101
  messages, tokenize=False, add_generation_prompt=True
102
  )
103
+
104
  image_inputs, video_inputs = process_vision_info(messages)
105
 
106
  inputs = self.processor(
 
113
  inputs = inputs.to(self.model.device)
114
 
115
  with torch.no_grad():
116
+ generated_ids = self.model.generate(
117
+ **inputs,
118
+ max_new_tokens=max_new_tokens
119
+ )
120
 
121
+ # Trim input tokens
122
  generated_ids_trimmed = [
123
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
124
  ]
125
+
126
  output_text = self.processor.batch_decode(
127
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
128
  )[0]
129
 
130
  return output_text
131
 
132
+ # Initialize global model
133
+ model = FaraModelWrapper(MODEL_ID, DEVICE)
134
 
135
  # -----------------------------------------------------------------------------
136
  # SELENIUM SANDBOX
 
168
 
169
  self.driver = webdriver.Chrome(service=service, options=chrome_opts)
170
  self.driver.set_window_size(width, height)
 
 
 
171
  print("Selenium started.")
172
  except Exception as e:
173
  print(f"Selenium init failed: {e}")
 
178
  return Image.open(BytesIO(self.driver.get_screenshot_as_png()))
179
 
180
  def execute_action(self, action_data: dict):
181
+ """Execute parsed action on the browser"""
182
+ action_type = action_data.get('type')
 
 
 
183
 
184
  try:
185
  actions = ActionChains(self.driver)
186
  body = self.driver.find_element(By.TAG_NAME, "body")
187
+
188
+ # Helper to move to coordinates
189
+ def move_to(x_norm, y_norm):
190
+ # Convert normalized (0-1) to pixel coordinates
 
 
 
 
191
  x_px = int(x_norm * self.width)
192
  y_px = int(y_norm * self.height)
 
 
193
  actions.move_to_element_with_offset(body, 0, 0)
194
  actions.move_by_offset(x_px, y_px)
195
+
196
+ if action_type in ['click', 'right_click', 'double_click']:
197
+ move_to(action_data['x'], action_data['y'])
198
+ if action_type == 'click': actions.click()
199
+ elif action_type == 'right_click': actions.context_click()
200
+ elif action_type == 'double_click': actions.double_click()
201
  actions.perform()
202
 
203
+ elif action_type == 'type_text':
204
+ text = action_data.get('text', '')
 
 
 
 
 
 
 
 
 
 
 
 
205
  actions.send_keys(text)
 
 
206
  actions.perform()
207
+
208
+ elif action_type == 'press_key':
209
+ key_name = action_data.get('key', '').lower()
210
+ k = getattr(Keys, key_name.upper(), None)
211
+ if not k:
212
+ if key_name == "enter": k = Keys.ENTER
213
+ elif key_name == "space": k = Keys.SPACE
214
+ elif key_name == "backspace": k = Keys.BACK_SPACE
215
+ if k:
216
+ actions.send_keys(k)
217
+ actions.perform()
 
 
218
 
219
+ elif action_type == 'scroll':
220
+ amount = action_data.get('amount', 2)
221
+ direction = action_data.get('direction', 'down')
222
+ scroll_y = amount * 100
223
+ if direction == 'up': scroll_y = -scroll_y
224
+ self.driver.execute_script(f"window.scrollBy(0, {scroll_y});")
225
+
226
+ elif action_type == 'open_url':
227
+ url = action_data.get('url', '')
228
+ if not url.startswith('http'): url = 'https://' + url
229
+ self.driver.get(url)
230
+ time.sleep(2)
231
+
232
+ return f"Executed {action_type}"
233
  except Exception as e:
 
234
  return f"Action failed: {e}"
235
 
236
  def cleanup(self):
 
239
  shutil.rmtree(self.tmp_dir, ignore_errors=True)
240
 
241
  # -----------------------------------------------------------------------------
242
+ # PARSING LOGIC
243
  # -----------------------------------------------------------------------------
244
 
245
+ def parse_code_block(response: str) -> str:
246
+ pattern = r"<code>\s*(.*?)\s*</code>"
247
+ matches = re.findall(pattern, response, re.DOTALL)
248
+ if matches:
249
+ return matches[-1].strip() # Return the last code block
250
+ return ""
251
+
252
+ def parse_action_string(action_str: str) -> dict:
253
+ """Parse string like 'click(x=0.5, y=0.5)' into a dict"""
254
+ # Simple regex parsing for demonstration
255
+ action_data = {}
256
 
257
+ # 1. Coordinate actions: name(x=..., y=...)
258
+ coord_match = re.match(r"(\w+)\s*\(\s*x\s*=\s*([0-9.]+)\s*,\s*y\s*=\s*([0-9.]+)\s*\)", action_str)
259
+ if coord_match:
260
+ return {
261
+ "type": coord_match.group(1),
262
+ "x": float(coord_match.group(2)),
263
+ "y": float(coord_match.group(3))
264
+ }
265
+
266
+ # 2. Open URL: open_url(url="...")
267
+ url_match = re.match(r"open_url\s*\(\s*url\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
268
+ if url_match:
269
+ return {"type": "open_url", "url": url_match.group(1)}
270
+
271
+ # 3. Type text: type_text(text="...")
272
+ text_match = re.match(r"type_text\s*\(\s*text\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
273
+ if text_match:
274
+ return {"type": "type_text", "text": text_match.group(1)}
275
+
276
+ # 4. Press key: press_key(key="...")
277
+ key_match = re.match(r"press_key\s*\(\s*key\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
278
+ if key_match:
279
+ return {"type": "press_key", "key": key_match.group(1)}
280
+
281
+ # 5. Scroll: scroll(amount=..., direction="...")
282
+ if "scroll" in action_str:
283
+ return {"type": "scroll", "amount": 2, "direction": "down"} # Default
284
+
285
+ return {}
286
 
287
  # -----------------------------------------------------------------------------
288
+ # MAIN LOOP
289
  # -----------------------------------------------------------------------------
290
 
 
 
 
291
  @spaces.GPU(duration=120)
292
  def agent_step(task_instruction: str, history: list, sandbox_state: dict):
293
+ # Initialize sandbox if needed (handled via state in Gradio mostly, but for safety)
294
  if 'uuid' not in sandbox_state:
295
  sandbox_state['uuid'] = str(uuid.uuid4())
296
+ sandbox = SeleniumSandbox(WIDTH, HEIGHT)
297
+ # Store sandbox instance reference globally or handle cleanup carefully
298
+ # For this demo, we'll recreate/attach to session based on state if persisting,
299
+ # but here we'll assume a persistent session for the run.
300
 
301
+ # HACK: For Gradio state persistence with objects that can't be pickled easily,
302
+ # we often use a global dict mapping UUID -> Sandbox
303
+ sandbox_id = sandbox_state['uuid']
304
+ if sandbox_id not in SANDBOX_REGISTRY:
305
+ SANDBOX_REGISTRY[sandbox_id] = SeleniumSandbox(WIDTH, HEIGHT)
306
 
307
+ sandbox = SANDBOX_REGISTRY[sandbox_id]
308
 
309
+ # 1. Get Screenshot
310
  screenshot = sandbox.get_screenshot()
311
 
312
+ # 2. Construct Prompt
313
+ # Convert history text to string context if needed
 
314
 
315
  messages = [
 
316
  {
317
+ "role": "system",
318
+ "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]
319
+ },
320
+ {
321
+ "role": "user",
322
  "content": [
323
  {"type": "image", "image": screenshot},
324
+ {"type": "text", "text": f"Instruction: {task_instruction}\nPrevious Actions: {history[-1] if history else 'None'}"}
325
  ]
326
  }
327
  ]
328
+
329
+ # 3. Model Inference
330
  response = model.generate(messages)
331
 
332
+ # 4. Parse Action
333
+ action_code = parse_code_block(response)
334
+ action_data = parse_action_string(action_code)
335
 
336
+ log_entry = f"Step: {len(history)+1}\nModel Thought: {response}\nAction: {action_code}"
337
 
338
+ # 5. Execute Action
339
+ execution_result = "No valid action found"
340
  if action_data:
341
+ execution_result = sandbox.execute_action(action_data)
 
342
 
343
+ # Draw marker if coordinate action
344
+ if 'x' in action_data:
 
345
  draw = ImageDraw.Draw(screenshot)
346
+ x_px = action_data['x'] * WIDTH
347
+ y_px = action_data['y'] * HEIGHT
348
+ r = 10
349
+ draw.ellipse((x_px-r, y_px-r, x_px+r, y_px+r), outline="red", width=3)
 
 
 
350
 
351
+ log_entry += f"\nResult: {execution_result}"
352
  history.append(log_entry)
353
 
354
+ # Return updated screenshot and history
355
  return screenshot, history, sandbox_state
356
 
357
+ # Global registry for sandboxes
358
+ SANDBOX_REGISTRY = {}
359
+
360
  def cleanup_sandbox(sandbox_state):
361
  sid = sandbox_state.get('uuid')
362
  if sid and sid in SANDBOX_REGISTRY:
 
368
  # GRADIO UI
369
  # -----------------------------------------------------------------------------
370
 
371
+ def run_task_loop(task, history, state):
372
+ # This generator function runs the agent loop
373
+ max_steps = 10
374
+
375
+ for i in range(max_steps):
376
  try:
377
+ # Run one step
378
+ screenshot, new_history, new_state = agent_step(task, history, state)
379
+ history = new_history
380
 
381
+ # Yield updates to UI
382
+ # We yield the logs (joined) and the latest image
383
+ logs_text = "\n\n" + "-"*40 + "\n\n".join(history)
384
+ yield screenshot, logs_text, state
385
+
386
+ # Check for termination (simplistic)
387
+ if "Done" in history[-1] or "finished" in history[-1].lower():
388
+ break
389
+
390
+ time.sleep(1) # Pause for visual effect
391
 
 
 
392
  except Exception as e:
393
+ error_msg = f"Error in loop: {e}"
394
+ history.append(error_msg)
395
  yield None, "\n".join(history), state
396
  break
397
 
398
+ # UI Layout
399
  custom_css = """
400
+ #view_img { height: 600px; object-fit: contain; }
401
  """
402
 
403
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
404
  state = gr.State({})
405
  history = gr.State([])
406
 
407
+ gr.Markdown("# 🤖 Fara CUA - Chrome Agent")
408
+
 
409
  with gr.Row():
410
  with gr.Column(scale=1):
411
+ task_input = gr.Textbox(label="Task Instruction", value="Go to google.com and search for 'SpaceX'")
412
+ run_btn = gr.Button("Run Agent", variant="primary")
413
+ clear_btn = gr.Button("Reset / Clear")
 
 
 
 
 
414
 
 
 
 
 
 
 
415
  with gr.Column(scale=2):
416
+ browser_view = gr.Image(label="Live Browser View", elem_id="view_img", interactive=False)
417
+
418
+ logs_output = gr.Textbox(label="Agent Logs", lines=15, interactive=False)
 
 
 
 
 
419
 
420
+ # Event handlers
421
  run_btn.click(
422
+ fn=run_task_loop,
423
  inputs=[task_input, history, state],
424
+ outputs=[browser_view, logs_output, state]
425
  )
426
 
427
+ clear_btn.click(
428
  fn=cleanup_sandbox,
429
  inputs=[state],
430
  outputs=[history, state]
431
  ).then(
432
  lambda: (None, ""),
433
+ outputs=[browser_view, logs_output]
434
  )
435
 
436
  if __name__ == "__main__":
437
+ demo.launch()