prithivMLmods commited on
Commit
7e47e30
·
verified ·
1 Parent(s): 640c489

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -641
app.py CHANGED
@@ -1,21 +1,25 @@
1
- import json
2
  import os
3
- import shutil
4
  import time
 
5
  import uuid
6
  import tempfile
7
- import atexit
8
  import unicodedata
9
  from io import BytesIO
10
- from threading import Timer
11
- from typing import Any, Dict, List, Optional
12
- from datetime import datetime
13
 
14
  import gradio as gr
 
15
  import torch
16
  import spaces
17
- from dotenv import load_dotenv
18
- from PIL import Image, ImageDraw
 
 
 
 
 
 
19
 
20
  # Selenium Imports
21
  from selenium import webdriver
@@ -26,190 +30,116 @@ from selenium.webdriver.common.by import By
26
  from selenium.webdriver.common.keys import Keys
27
  from webdriver_manager.chrome import ChromeDriverManager
28
 
29
- # Smolagents imports
30
- from smolagents import CodeAgent, tool, AgentImage
31
- from smolagents.memory import ActionStep, TaskStep
32
- from smolagents.models import ChatMessage, Model, MessageRole
33
- from smolagents.gradio_ui import GradioUI, stream_to_gradio
34
- from smolagents.monitoring import LogLevel
35
-
36
- # Transformers for Fara Model
37
- from transformers import (
38
- Qwen2_5_VLForConditionalGeneration,
39
- AutoProcessor,
40
- )
41
- from qwen_vl_utils import process_vision_info
42
-
43
- load_dotenv(override=True)
44
-
45
  # -----------------------------------------------------------------------------
46
- # CONFIGURATION & CONSTANTS
47
  # -----------------------------------------------------------------------------
48
 
49
- HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
50
- if HF_TOKEN:
51
- from huggingface_hub import login
52
- login(token=HF_TOKEN)
53
-
54
- # Browser Sandbox Config
55
  WIDTH = 1024
56
  HEIGHT = 768
57
- TMP_DIR = "./tmp/"
58
  if not os.path.exists(TMP_DIR):
59
  os.makedirs(TMP_DIR)
60
 
61
- # -----------------------------------------------------------------------------
62
- # MODEL INITIALIZATION (Fara-7B / Qwen2.5-VL)
63
- # -----------------------------------------------------------------------------
64
-
65
- print("Loading Fara Model... This may take a moment.")
66
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
67
- MODEL_ID_F = "microsoft/Fara-7B"
68
-
69
- # Global model variables
70
- model_f = None
71
- processor_f = None
72
-
73
- try:
74
- processor_f = AutoProcessor.from_pretrained(MODEL_ID_F, trust_remote_code=True)
75
- model_f = Qwen2_5_VLForConditionalGeneration.from_pretrained(
76
- MODEL_ID_F,
77
- trust_remote_code=True,
78
- torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
79
- device_map="auto",
80
- )
81
- print(f"Fara Model loaded successfully on {DEVICE}")
82
- except Exception as e:
83
- print(f"Error loading Fara Model: {e}")
84
- print("Falling back to Qwen/Qwen2.5-VL-7B-Instruct...")
85
- try:
86
- MODEL_ID_F = "Qwen/Qwen2.5-VL-7B-Instruct"
87
- processor_f = AutoProcessor.from_pretrained(MODEL_ID_F, trust_remote_code=True)
88
- model_f = Qwen2_5_VLForConditionalGeneration.from_pretrained(
89
- MODEL_ID_F,
90
- trust_remote_code=True,
91
- torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
92
- device_map="auto",
93
- )
94
- print(f"Fallback Model ({MODEL_ID_F}) loaded successfully.")
95
- except Exception as inner_e:
96
- print(f"Critical error loading model: {inner_e}")
97
-
98
 
99
  # -----------------------------------------------------------------------------
100
- # GPU ISOLATED INFERENCE FUNCTION
101
  # -----------------------------------------------------------------------------
102
 
103
- @spaces.GPU(duration=120)
104
- def run_model_inference(formatted_messages, max_tokens=1024, stop_sequences=None):
105
- """
106
- Runs inference on the GPU worker.
107
- """
108
- global model_f, processor_f
109
-
110
- if model_f is None:
111
- raise ValueError("Model is not loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- text = processor_f.apply_chat_template(
114
- formatted_messages, tokenize=False, add_generation_prompt=True
115
- )
116
-
117
- image_inputs, video_inputs = process_vision_info(formatted_messages)
118
-
119
- inputs = processor_f(
120
- text=[text],
121
- images=image_inputs,
122
- videos=video_inputs,
123
- padding=True,
124
- return_tensors="pt",
125
- )
126
-
127
- inputs = inputs.to(model_f.device)
128
-
129
- with torch.no_grad():
130
- generated_ids = model_f.generate(
131
- **inputs,
132
- max_new_tokens=max_tokens,
133
- stop_strings=stop_sequences,
134
- tokenizer=processor_f.tokenizer,
135
  )
136
-
137
- generated_ids_trimmed = [
138
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
139
- ]
140
- output_text = processor_f.batch_decode(
141
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
142
- )[0]
143
-
144
- return output_text
145
-
146
-
147
- class FaraLocalModel(Model):
148
- """
149
- Wrapper for the local Fara (Qwen2.5-VL) model to work with SmolAgents.
150
- """
151
- def __init__(self, **kwargs):
152
- super().__init__(**kwargs)
153
-
154
- def __call__(
155
- self,
156
- messages: List[ChatMessage],
157
- stop_sequences: Optional[List[str]] = None,
158
- grammar: Optional[str] = None,
159
- **kwargs,
160
- ) -> ChatMessage:
161
 
162
- formatted_messages = []
163
 
164
- for msg in messages:
165
- # Safely access role and content from ChatMessage object using attributes
166
- role = msg.role if hasattr(msg, "role") else "user"
167
- content = msg.content if hasattr(msg, "content") else ""
168
-
169
- new_content = []
170
-
171
- if isinstance(content, str):
172
- new_content.append({"type": "text", "text": content})
173
- elif isinstance(content, list):
174
- for item in content:
175
- if isinstance(item, str):
176
- new_content.append({"type": "text", "text": item})
177
- elif isinstance(item, dict):
178
- if "type" in item:
179
- if item["type"] == "image":
180
- # Handle path or url - extract value to ensure serializability
181
- val = item.get("image") or item.get("url") or item.get("path")
182
- new_content.append({"type": "image", "image": val})
183
- else:
184
- new_content.append(item)
185
-
186
- formatted_messages.append({"role": role, "content": new_content})
187
-
188
- output_text = run_model_inference(
189
- formatted_messages=formatted_messages,
190
- max_tokens=kwargs.get("max_tokens", 1024),
191
- stop_sequences=stop_sequences
192
  )
 
193
 
194
- return ChatMessage(
195
- role=MessageRole.ASSISTANT,
196
- content=output_text,
197
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  # -----------------------------------------------------------------------------
200
- # SELENIUM CHROME SANDBOX
201
  # -----------------------------------------------------------------------------
202
 
203
  def get_system_chrome_path():
204
- # Common paths for chromium in Linux/HF Spaces
205
- paths = [
206
- "/usr/bin/chromium",
207
- "/usr/bin/chromium-browser",
208
- "/usr/bin/google-chrome",
209
- ]
210
  for p in paths:
211
- if os.path.exists(p):
212
- return p
213
  return None
214
 
215
  class SeleniumSandbox:
@@ -218,535 +148,290 @@ class SeleniumSandbox:
218
  self.height = height
219
  self.tmp_dir = tempfile.mkdtemp(prefix="chrome_sandbox_")
220
 
221
- # Setup Chrome Options
222
  chrome_opts = ChromeOptions()
223
-
224
- # Use system binary if available (fixes status 127 in HF Spaces)
225
  binary_path = get_system_chrome_path()
226
- if binary_path:
227
- print(f"Using system Chrome binary at: {binary_path}")
228
- chrome_opts.binary_location = binary_path
229
 
230
  chrome_opts.add_argument("--headless=new")
231
  chrome_opts.add_argument(f"--user-data-dir={self.tmp_dir}")
232
  chrome_opts.add_argument(f"--window-size={width},{height}")
233
- chrome_opts.add_argument("--no-sandbox") # Crucial for containers
234
- chrome_opts.add_argument("--disable-dev-shm-usage") # Crucial for containers
235
  chrome_opts.add_argument("--disable-gpu")
236
- chrome_opts.add_argument("--disable-extensions")
237
 
238
- # Initialize Driver
239
  try:
240
- # Check for system driver first
241
  system_driver_path = "/usr/bin/chromedriver"
242
  if os.path.exists(system_driver_path):
243
- print(f"Using system ChromeDriver at: {system_driver_path}")
244
  service = ChromeService(executable_path=system_driver_path)
245
  else:
246
- print("Using webdriver_manager to install ChromeDriver...")
247
  service = ChromeService(ChromeDriverManager().install())
248
 
249
  self.driver = webdriver.Chrome(service=service, options=chrome_opts)
250
  self.driver.set_window_size(width, height)
251
- self.driver.get("about:blank")
252
- print(f"Selenium Chrome Driver started successfully.")
253
-
254
  except Exception as e:
255
- print(f"Failed to initialize Selenium: {e}")
256
- self.cleanup()
257
  raise e
258
 
259
  def get_screenshot(self):
260
- """Returns screenshot as PIL Image"""
261
- png_data = self.driver.get_screenshot_as_png()
262
- return Image.open(BytesIO(png_data))
263
 
264
- def move_mouse_and_click(self, x, y, click_type="left"):
265
- try:
266
- body = self.driver.find_element(By.TAG_NAME, "body")
267
- actions = ActionChains(self.driver)
268
- actions.move_to_element_with_offset(body, 0, 0)
269
- actions.move_by_offset(x, y)
270
- if click_type == "left":
271
- actions.click()
272
- elif click_type == "right":
273
- actions.context_click()
274
- elif click_type == "double":
275
- actions.double_click()
276
- actions.perform()
277
- except Exception as e:
278
- print(f"Error in move_mouse_and_click: {e}")
279
-
280
- def drag_and_drop(self, x1, y1, x2, y2):
281
- try:
282
- body = self.driver.find_element(By.TAG_NAME, "body")
283
- actions = ActionChains(self.driver)
284
- actions.move_to_element_with_offset(body, 0, 0)
285
- actions.move_by_offset(x1, y1)
286
- actions.click_and_hold()
287
- actions.move_by_offset(x2 - x1, y2 - y1)
288
- actions.release()
289
- actions.perform()
290
- except Exception as e:
291
- print(f"Error in drag_and_drop: {e}")
292
-
293
- def type_text(self, text):
294
- actions = ActionChains(self.driver)
295
- actions.send_keys(text)
296
- actions.perform()
297
 
298
- def press_key(self, key_name):
299
  try:
300
- k = getattr(Keys, key_name.upper(), None)
301
- if not k:
302
- if key_name.lower() == "enter": k = Keys.ENTER
303
- elif key_name.lower() == "space": k = Keys.SPACE
304
- elif key_name.lower() == "backspace": k = Keys.BACK_SPACE
305
- elif key_name.lower() == "esc": k = Keys.ESCAPE
306
- else: k = key_name
307
  actions = ActionChains(self.driver)
308
- actions.send_keys(k)
309
- actions.perform()
310
- except Exception as e:
311
- print(f"Error pressing key: {e}")
312
-
313
- def scroll(self, amount, direction="down"):
314
- try:
315
- scroll_y = amount * 100
316
- if direction == "up":
317
- scroll_y = -scroll_y
318
- self.driver.execute_script(f"window.scrollBy(0, {scroll_y});")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  except Exception as e:
320
- print(f"Error scrolling: {e}")
321
 
322
  def cleanup(self):
323
- try:
324
- if hasattr(self, 'driver'):
325
- self.driver.quit()
326
- except:
327
- pass
328
  shutil.rmtree(self.tmp_dir, ignore_errors=True)
329
 
330
  # -----------------------------------------------------------------------------
331
- # AGENT SETUP
332
  # -----------------------------------------------------------------------------
333
 
334
- SYSTEM_PROMPT_TEMPLATE = """You are a browser automation assistant controlling a Google Chrome web browser. The current date is <<current_date>>.
335
-
336
- <action process>
337
- You will be given a task to solve in several steps. At each step you will perform an action.
338
- After each action, you'll receive an updated screenshot of the browser.
339
- Then you will proceed as follows, with these sections: don't skip any!
340
-
341
- Short term goal: ...
342
- What I see: ...
343
- Reflection: ...
344
- Action:
345
- ```python
346
- click(254, 308)
347
- ```<end_code>
348
-
349
- Always format your action ('Action:' part) as Python code blocks as shown above.
350
- </action_process>
351
-
352
- <tools>
353
- On top of performing computations in the Python code snippets that you create, you only have access to these tools to interact with the browser:
354
- {%- for tool in tools.values() %}
355
- - {{ tool.name }}: {{ tool.description }}
356
- Takes inputs: {{tool.inputs}}
357
- Returns an output of type: {{tool.output_type}}
358
- {%- endfor %}
359
- </tools>
360
-
361
- <click_guidelines>
362
- The browser has a resolution of <<resolution_x>>x<<resolution_y>> pixels.
363
- NEVER USE HYPOTHETIC OR ASSUMED COORDINATES, USE TRUE COORDINATES that you can see from the screenshot.
364
- Use precise coordinates based on the current screenshot.
365
- Whenever you click, MAKE SURE to click in the middle of the button, text, link or any other clickable element.
366
- In the screenshot you will see a green crosshair displayed over the position of your last click.
367
- </click_guidelines>
368
-
369
- <general_guidelines>
370
- Execute one action at a time.
371
- Use `open_url` to navigate to websites.
372
- Use `click` to navigate links and interface elements.
373
- Use `type_text` to input into forms.
374
- Use `scroll` to see more content.
375
- If you get stuck, try using `open_url` to search on Google.
376
- </general_guidelines>
377
- """.replace("<<current_date>>", datetime.now().strftime("%A, %d-%B-%Y"))
378
-
379
- def draw_marker_on_image(image_copy, click_coordinates):
380
- x, y = click_coordinates
381
- draw = ImageDraw.Draw(image_copy)
382
- cross_size, linewidth = 10, 3
383
- # Draw cross
384
- draw.line((x - cross_size, y, x + cross_size, y), fill="green", width=linewidth)
385
- draw.line((x, y - cross_size, x, y + cross_size), fill="green", width=linewidth)
386
- draw.ellipse(
387
- (x - cross_size * 2, y - cross_size * 2, x + cross_size * 2, y + cross_size * 2),
388
- outline="green",
389
- width=linewidth,
390
- )
391
- return image_copy
392
-
393
- class SeleniumVisionAgent(CodeAgent):
394
- """Agent for Browser automation with Selenium and Vision"""
395
-
396
- def __init__(
397
- self,
398
- model: Model,
399
- data_dir: str,
400
- sandbox: SeleniumSandbox,
401
- max_steps: int = 20,
402
- verbosity_level: LogLevel = 2,
403
- **kwargs,
404
- ):
405
- self.sandbox = sandbox
406
- self.data_dir = data_dir
407
-
408
- # Initialize
409
- print(f"Browser size: {self.sandbox.width}x{self.sandbox.height}")
410
- os.makedirs(self.data_dir, exist_ok=True)
411
-
412
- # Build tools list
413
- tools_list = self.build_tools()
414
-
415
- super().__init__(
416
- tools=tools_list,
417
- model=model,
418
- max_steps=max_steps,
419
- verbosity_level=verbosity_level,
420
- step_callbacks=[self.take_screenshot_callback],
421
- **kwargs,
422
- )
423
 
424
- self.prompt_templates["system_prompt"] = SYSTEM_PROMPT_TEMPLATE.replace(
425
- "<<resolution_x>>", str(self.sandbox.width)
426
- ).replace("<<resolution_y>>", str(self.sandbox.height))
427
-
428
- def build_tools(self):
429
- """Define and return the list of tools for this agent"""
 
 
430
 
431
- @tool
432
- def click(x: int, y: int) -> str:
433
- """
434
- Performs a left-click at the specified coordinates.
435
- Args:
436
- x: The x coordinate (horizontal position).
437
- y: The y coordinate (vertical position).
438
- """
439
- self.sandbox.move_mouse_and_click(x, y, "left")
440
- self.click_coordinates = [x, y]
441
- return f"Clicked at ({x}, {y})"
442
-
443
- @tool
444
- def right_click(x: int, y: int) -> str:
445
- """
446
- Performs a right-click at the specified coordinates.
447
- Args:
448
- x: The x coordinate.
449
- y: The y coordinate.
450
- """
451
- self.sandbox.move_mouse_and_click(x, y, "right")
452
- self.click_coordinates = [x, y]
453
- return f"Right-clicked at ({x}, {y})"
454
-
455
- @tool
456
- def double_click(x: int, y: int) -> str:
457
- """
458
- Performs a double-click at the specified coordinates.
459
- Args:
460
- x: The x coordinate.
461
- y: The y coordinate.
462
- """
463
- self.sandbox.move_mouse_and_click(x, y, "double")
464
- self.click_coordinates = [x, y]
465
- return f"Double-clicked at ({x}, {y})"
466
-
467
- @tool
468
- def type_text(text: str) -> str:
469
- """
470
- Types the specified text.
471
- Args:
472
- text: The text to type.
473
- """
474
- clean_text = unicodedata.normalize("NFD", text)
475
- self.sandbox.type_text(clean_text)
476
- return f"Typed text: '{clean_text}'"
477
-
478
- @tool
479
- def press_key(key: str) -> str:
480
- """
481
- Presses a keyboard key (e.g., 'enter', 'backspace', 'esc').
482
- Args:
483
- key: The key name.
484
- """
485
- self.sandbox.press_key(key)
486
- return f"Pressed key: {key}"
487
-
488
- @tool
489
- def drag_and_drop(x1: int, y1: int, x2: int, y2: int) -> str:
490
- """
491
- Drags from (x1, y1) and drops at (x2, y2).
492
- Args:
493
- x1: Start x coordinate.
494
- y1: Start y coordinate.
495
- x2: End x coordinate.
496
- y2: End y coordinate.
497
- """
498
- self.sandbox.drag_and_drop(x1, y1, x2, y2)
499
- return f"Dragged from [{x1}, {y1}] to [{x2}, {y2}]"
500
-
501
- @tool
502
- def scroll(amount: int, direction: str = "down") -> str:
503
- """
504
- Scrolls the page.
505
- Args:
506
- amount: The amount to scroll (1-10).
507
- direction: "up" or "down".
508
- """
509
- self.sandbox.scroll(amount, direction)
510
- return f"Scrolled {direction} by {amount}"
511
-
512
- @tool
513
- def wait(seconds: float) -> str:
514
- """
515
- Waits for the specified number of seconds.
516
- Args:
517
- seconds: The duration to wait.
518
- """
519
- time.sleep(seconds)
520
- return f"Waited for {seconds} seconds"
521
-
522
- @tool
523
- def open_url(url: str) -> str:
524
- """
525
- Navigates the browser to the specified URL.
526
- Args:
527
- url: The URL to open.
528
- """
529
- if not url.startswith(("http://", "https://")):
530
- url = "https://" + url
531
- try:
532
- self.sandbox.driver.get(url)
533
- time.sleep(2)
534
- title = self.sandbox.driver.title
535
- return f"Opened URL: {url}. Page Title: {title}"
536
- except Exception as e:
537
- return f"Failed to open URL: {e}"
538
-
539
- @tool
540
- def go_back() -> str:
541
- """
542
- Goes back to the previous page in history.
543
- """
544
- self.sandbox.driver.back()
545
- return "Went back one page"
546
-
547
- return [click, right_click, double_click, type_text, press_key, drag_and_drop, scroll, wait, open_url, go_back]
548
-
549
-
550
- def take_screenshot_callback(self, memory_step: ActionStep, agent=None) -> None:
551
- """Takes a screenshot and saves it to memory"""
552
- current_step = memory_step.step_number
553
- time.sleep(1.0) # Wait for renders
554
-
555
- image = self.sandbox.get_screenshot()
556
 
557
- # Save to disk
558
- screenshot_path = os.path.join(self.data_dir, f"step_{current_step:03d}.png")
559
- image.save(screenshot_path)
560
-
561
- image_copy = image.copy()
562
- if getattr(self, "click_coordinates", None):
563
- image_copy = draw_marker_on_image(image_copy, self.click_coordinates)
564
 
565
- self.last_marked_screenshot = AgentImage(screenshot_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
 
567
- # Cleanup old images in memory to save RAM
568
- for previous_memory_step in agent.memory.steps:
569
- if isinstance(previous_memory_step, ActionStep) and previous_memory_step.step_number <= current_step - 1:
570
- previous_memory_step.observations_images = None
571
- elif isinstance(previous_memory_step, TaskStep):
572
- previous_memory_step.task_images = None
573
-
574
- memory_step.observations_images = [image_copy]
575
- self.click_coordinates = None
576
-
577
-
578
- def create_agent(data_dir, sandbox):
579
- model = FaraLocalModel()
580
- return SeleniumVisionAgent(
581
- model=model,
582
- data_dir=data_dir,
583
- sandbox=sandbox,
584
- max_steps=30,
585
- verbosity_level=2
586
- )
587
 
588
- def generate_interaction_id(session_uuid):
589
- return f"{session_uuid}_{int(time.time())}"
590
-
591
- def get_agent_summary_erase_images(agent):
592
- for memory_step in agent.memory.steps:
593
- if hasattr(memory_step, "observations_images"):
594
- memory_step.observations_images = None
595
- if hasattr(memory_step, "task_images"):
596
- memory_step.task_images = None
597
- return agent.write_memory_to_messages()
598
-
599
- def save_final_status(folder, status: str, summary, error_message=None) -> None:
600
- try:
601
- with open(os.path.join(folder, "metadata.json"), "w") as output_file:
602
- output_file.write(
603
- json.dumps(
604
- {"status": status, "summary": summary, "error_message": error_message},
605
- default=str
606
- )
607
- )
608
- except Exception as e:
609
- print(f"Failed to save metadata: {e}")
610
 
611
  # -----------------------------------------------------------------------------
612
- # UI & APP
613
  # -----------------------------------------------------------------------------
614
 
615
- custom_css = """
616
- .modal-container { margin: var(--size-16) auto!important; }
617
- .browser-container { position: relative; width: 100%; height: 600px; border: 1px solid #444; background: #222; display: flex; align-items: center; justify-content: center; overflow: hidden; }
618
- .browser-image { max-width: 100%; max-height: 100%; object-fit: contain; }
619
- #chatbot { height: 800px!important; }
620
- """
621
-
622
- class EnrichedGradioUI(GradioUI):
623
- def interact_with_agent(
624
- self,
625
- task_input,
626
- stored_messages,
627
- session_state,
628
- session_uuid,
629
- consent_storage,
630
- request: gr.Request,
631
- ):
632
- interaction_id = generate_interaction_id(session_uuid)
633
- data_dir = os.path.join(TMP_DIR, interaction_id)
634
-
635
- sandbox = SeleniumSandbox(width=WIDTH, height=HEIGHT)
636
- agent = create_agent(data_dir=data_dir, sandbox=sandbox)
637
- session_state["agent"] = agent
638
-
639
  try:
640
- stored_messages.append(gr.ChatMessage(role="user", content=task_input))
641
- yield stored_messages, None
642
-
643
- screenshot = sandbox.get_screenshot()
644
 
645
- for msg in stream_to_gradio(
646
- agent,
647
- task=task_input,
648
- task_images=[screenshot],
649
- reset_agent_memory=False,
650
- ):
651
- if hasattr(agent, "last_marked_screenshot") and msg.content == "-----":
652
- stored_messages.append(
653
- gr.ChatMessage(
654
- role="assistant",
655
- content={
656
- "path": agent.last_marked_screenshot.to_string(),
657
- "mime_type": "image/png",
658
- },
659
- )
660
- )
661
- yield stored_messages, agent.last_marked_screenshot.to_string()
662
- else:
663
- stored_messages.append(msg)
664
- yield stored_messages, None
665
-
666
- if consent_storage:
667
- summary = get_agent_summary_erase_images(agent)
668
- save_final_status(data_dir, "completed", summary=summary)
669
 
670
- yield stored_messages, None
671
-
672
  except Exception as e:
673
- error_message = f"Error in interaction: {str(e)}"
674
- print(error_message)
675
- stored_messages.append(
676
- gr.ChatMessage(role="assistant", content="Run failed:\n" + error_message)
677
- )
678
- yield stored_messages, None
679
- finally:
680
- sandbox.cleanup()
681
 
682
- theme = gr.themes.Default(
683
- font=["Oxanium", "sans-serif"], primary_hue="amber", secondary_hue="blue"
684
- )
 
685
 
686
- with gr.Blocks(theme=theme, css=custom_css) as demo:
687
- session_uuid_state = gr.State(lambda: str(uuid.uuid4()))
688
- session_state = gr.State({})
689
- stored_messages = gr.State([])
 
690
 
691
  with gr.Row():
692
  with gr.Column(scale=1):
693
- gr.Markdown("### Fara CUA - Chrome Agent 🌐")
 
 
694
 
695
- task_input = gr.Textbox(
696
- value="Go to google.com and search for 'Hugging Face'",
697
- label="Task",
698
- lines=3
699
- )
700
- run_btn = gr.Button("Start Task", variant="primary")
701
- stop_btn = gr.Button("Stop", variant="secondary")
702
- consent_storage = gr.Checkbox(label="Save logs locally?", value=True)
703
 
704
- gr.Examples(
705
- examples=[
706
- "Go to google.com and search for 'Hugging Face', then click the first link.",
707
- "Go to wikipedia.org, type 'Python' in search, and click the search button.",
708
- ],
709
- inputs=task_input
710
- )
711
 
712
- with gr.Column(scale=3):
713
- with gr.Row():
714
- with gr.Column(scale=1):
715
- chatbot_display = gr.Chatbot(
716
- label="Agent Trace",
717
- type="messages",
718
- height=800,
719
- avatar_images=(None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"),
720
- )
721
- with gr.Column(scale=1):
722
- gr.Markdown("### Latest Browser View")
723
- live_browser_view = gr.Image(
724
- label="Browser View",
725
- type="filepath",
726
- interactive=False,
727
- height=600
728
- )
729
-
730
- agent_ui = EnrichedGradioUI(CodeAgent(tools=[], model=Model(), name="init"))
731
-
732
- def interrupt_agent(session_state):
733
- if "agent" in session_state and hasattr(session_state["agent"], "interrupt_switch"):
734
- session_state["agent"].interrupt_switch = True
735
- return "Interrupted"
736
-
737
- run_event = run_btn.click(
738
- fn=agent_ui.interact_with_agent,
739
- inputs=[
740
- task_input,
741
- stored_messages,
742
- session_state,
743
- session_uuid_state,
744
- consent_storage,
745
- ],
746
- outputs=[chatbot_display, live_browser_view]
747
  )
748
 
749
- stop_btn.click(fn=interrupt_agent, inputs=[session_state], outputs=[])
 
 
 
 
 
 
 
750
 
751
  if __name__ == "__main__":
752
- demo.launch(share=True)
 
 
1
  import os
2
+ import re
3
  import time
4
+ import shutil
5
  import uuid
6
  import tempfile
 
7
  import unicodedata
8
  from io import BytesIO
9
+ from typing import Tuple, Optional, List, Dict, Any
 
 
10
 
11
  import gradio as gr
12
+ import numpy as np
13
  import torch
14
  import spaces
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ # Transformers imports
18
+ from transformers import (
19
+ Qwen2_5_VLForConditionalGeneration,
20
+ AutoProcessor,
21
+ )
22
+ from qwen_vl_utils import process_vision_info
23
 
24
  # Selenium Imports
25
  from selenium import webdriver
 
30
  from selenium.webdriver.common.keys import Keys
31
  from webdriver_manager.chrome import ChromeDriverManager
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # -----------------------------------------------------------------------------
34
+ # CONSTANTS & CONFIG
35
  # -----------------------------------------------------------------------------
36
 
37
+ MODEL_ID = "microsoft/Fara-7B" # Or your specific Fara model repo
38
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
39
  WIDTH = 1024
40
  HEIGHT = 768
41
+ TMP_DIR = "./tmp"
42
  if not os.path.exists(TMP_DIR):
43
  os.makedirs(TMP_DIR)
44
 
45
+ # System Prompt adapted for Fara/GUI agents
46
+ OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
47
+ You need to generate the next action to complete the task.
48
+
49
+ Supported actions:
50
+ 1. `click(x=0.5, y=0.5)`: Click at the specific location.
51
+ 2. `right_click(x=0.5, y=0.5)`: Right click at the specific location.
52
+ 3. `double_click(x=0.5, y=0.5)`: Double click at the specific location.
53
+ 4. `type_text(text="hello")`: Type the text.
54
+ 5. `scroll(amount=2, direction="down")`: Scroll the page.
55
+ 6. `press_key(key="enter")`: Press a specific key.
56
+ 7. `open_url(url="https://google.com")`: Open a specific URL.
57
+
58
+ Output format:
59
+ Please wrap the action code in <code> </code> tags.
60
+ Example:
61
+ <code>
62
+ click(x=0.23, y=0.45)
63
+ </code>
64
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # -----------------------------------------------------------------------------
67
+ # MODEL WRAPPER (Replacing smolagents)
68
  # -----------------------------------------------------------------------------
69
 
70
+ class FaraModelWrapper:
71
+ def __init__(self, model_id: str, to_device: str = "cuda"):
72
+ print(f"Loading {model_id} on {to_device}...")
73
+ self.model_id = model_id
74
+
75
+ try:
76
+ self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
77
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
78
+ model_id,
79
+ trust_remote_code=True,
80
+ torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
81
+ device_map="auto" if to_device == "cuda" else None,
82
+ )
83
+ if to_device == "cpu":
84
+ self.model.to("cpu")
85
+ self.model.eval()
86
+ print("Model loaded successfully.")
87
+ except Exception as e:
88
+ print(f"Failed to load Fara, falling back to Qwen2.5-VL-7B for demo compatibility. Error: {e}")
89
+ fallback_id = "Qwen/Qwen2.5-VL-7B-Instruct"
90
+ self.processor = AutoProcessor.from_pretrained(fallback_id, trust_remote_code=True)
91
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
92
+ fallback_id,
93
+ trust_remote_code=True,
94
+ torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
95
+ device_map="auto",
96
+ )
97
 
98
+ def generate(self, messages: list[dict], max_new_tokens=512):
99
+ # Prepare inputs for Fara/QwenVL
100
+ text = self.processor.apply_chat_template(
101
+ messages, tokenize=False, add_generation_prompt=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ image_inputs, video_inputs = process_vision_info(messages)
105
 
106
+ inputs = self.processor(
107
+ text=[text],
108
+ images=image_inputs,
109
+ videos=video_inputs,
110
+ padding=True,
111
+ return_tensors="pt",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  )
113
+ inputs = inputs.to(self.model.device)
114
 
115
+ with torch.no_grad():
116
+ generated_ids = self.model.generate(
117
+ **inputs,
118
+ max_new_tokens=max_new_tokens
119
+ )
120
+
121
+ # Trim input tokens
122
+ generated_ids_trimmed = [
123
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
124
+ ]
125
+
126
+ output_text = self.processor.batch_decode(
127
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
128
+ )[0]
129
+
130
+ return output_text
131
+
132
+ # Initialize global model
133
+ model = FaraModelWrapper(MODEL_ID, DEVICE)
134
 
135
  # -----------------------------------------------------------------------------
136
+ # SELENIUM SANDBOX
137
  # -----------------------------------------------------------------------------
138
 
139
  def get_system_chrome_path():
140
+ paths = ["/usr/bin/chromium", "/usr/bin/chromium-browser", "/usr/bin/google-chrome"]
 
 
 
 
 
141
  for p in paths:
142
+ if os.path.exists(p): return p
 
143
  return None
144
 
145
class SeleniumSandbox:
    """Headless Chrome sandbox driven by Selenium.

    Owns a dedicated temporary user-data directory and a single headless
    Chrome instance sized to ``width`` x ``height``. Callers must invoke
    :meth:`cleanup` when finished to release the browser process and the
    temporary profile directory.
    """

    def __init__(self, width, height):
        # NOTE(review): the __init__ signature was reconstructed from the
        # call site SeleniumSandbox(WIDTH, HEIGHT) and the self.width /
        # self.height reads below -- confirm against the original source.
        self.width = width
        self.height = height
        # Fresh profile dir per sandbox so concurrent sessions don't collide.
        self.tmp_dir = tempfile.mkdtemp(prefix="chrome_sandbox_")

        chrome_opts = ChromeOptions()
        binary_path = get_system_chrome_path()
        if binary_path:
            chrome_opts.binary_location = binary_path

        chrome_opts.add_argument("--headless=new")
        chrome_opts.add_argument(f"--user-data-dir={self.tmp_dir}")
        chrome_opts.add_argument(f"--window-size={width},{height}")
        # Required when running as root in containers (e.g. HF Spaces).
        chrome_opts.add_argument("--no-sandbox")
        chrome_opts.add_argument("--disable-dev-shm-usage")
        chrome_opts.add_argument("--disable-gpu")

        try:
            # Prefer the distro-provided chromedriver; fall back to a
            # webdriver-manager download when it is absent.
            system_driver_path = "/usr/bin/chromedriver"
            if os.path.exists(system_driver_path):
                service = ChromeService(executable_path=system_driver_path)
            else:
                service = ChromeService(ChromeDriverManager().install())

            self.driver = webdriver.Chrome(service=service, options=chrome_opts)
            self.driver.set_window_size(width, height)
            print("Selenium started.")
        except Exception as e:
            print(f"Selenium init failed: {e}")
            # Don't leak the temp profile if the browser never came up.
            shutil.rmtree(self.tmp_dir, ignore_errors=True)
            raise e

    def get_screenshot(self):
        """Capture the current viewport as a PIL Image."""
        return Image.open(BytesIO(self.driver.get_screenshot_as_png()))

    def execute_action(self, action_data: dict):
        """Execute a parsed action on the browser.

        ``action_data`` is a dict produced by ``parse_action_string`` with a
        ``type`` key plus per-action fields: normalized 0-1 ``x``/``y`` for
        pointer actions, ``text``, ``key``, ``url``, or
        ``amount``/``direction`` for scrolling.

        Returns a short human-readable status string; never raises (all
        exceptions are folded into the returned message).
        """
        action_type = action_data.get('type')

        try:
            actions = ActionChains(self.driver)
            body = self.driver.find_element(By.TAG_NAME, "body")

            # Helper to move to coordinates
            def move_to(x_norm, y_norm):
                # Convert normalized (0-1) to pixel coordinates.
                # NOTE(review): Selenium 4 measures move_to_element_with_offset
                # from the element's CENTER, not its top-left corner -- if
                # clicks land off-target, this origin needs adjusting.
                x_px = int(x_norm * self.width)
                y_px = int(y_norm * self.height)
                actions.move_to_element_with_offset(body, 0, 0)
                actions.move_by_offset(x_px, y_px)

            if action_type in ['click', 'right_click', 'double_click']:
                move_to(action_data['x'], action_data['y'])
                if action_type == 'click': actions.click()
                elif action_type == 'right_click': actions.context_click()
                elif action_type == 'double_click': actions.double_click()
                actions.perform()

            elif action_type == 'type_text':
                text = action_data.get('text', '')
                actions.send_keys(text)
                actions.perform()

            elif action_type == 'press_key':
                key_name = action_data.get('key', '').lower()
                # Look the key up on selenium's Keys first, then map the
                # few aliases whose names differ (e.g. backspace).
                k = getattr(Keys, key_name.upper(), None)
                if not k:
                    if key_name == "enter": k = Keys.ENTER
                    elif key_name == "space": k = Keys.SPACE
                    elif key_name == "backspace": k = Keys.BACK_SPACE
                if k:
                    actions.send_keys(k)
                    actions.perform()

            elif action_type == 'scroll':
                amount = action_data.get('amount', 2)
                direction = action_data.get('direction', 'down')
                scroll_y = amount * 100  # ~100px per scroll "tick"
                if direction == 'up': scroll_y = -scroll_y
                self.driver.execute_script(f"window.scrollBy(0, {scroll_y});")

            elif action_type == 'open_url':
                url = action_data.get('url', '')
                if not url.startswith('http'): url = 'https://' + url
                self.driver.get(url)
                time.sleep(2)  # crude wait for the page to settle

            return f"Executed {action_type}"
        except Exception as e:
            return f"Action failed: {e}"

    def cleanup(self):
        """Quit the browser (best-effort) and remove the temp profile dir."""
        # BUG FIX: the original used a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit; only driver.quit() failures should
        # be ignored here.
        try:
            self.driver.quit()
        except Exception:
            pass
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
240
 
241
  # -----------------------------------------------------------------------------
242
+ # PARSING LOGIC
243
  # -----------------------------------------------------------------------------
244
 
245
def parse_code_block(response: str) -> str:
    """Extract the contents of the final <code>...</code> block in *response*.

    Returns an empty string when no code block is present.
    """
    found = re.findall(r"<code>\s*(.*?)\s*</code>", response, re.DOTALL)
    return found[-1].strip() if found else ""
251
+
252
def parse_action_string(action_str: str) -> dict:
    """Parse an action call string like ``click(x=0.5, y=0.5)`` into a dict.

    Supported forms (first match wins):
      * pointer actions  ``<name>(x=<float>, y=<float>)`` -> {"type", "x", "y"}
      * ``open_url(url="...")``                           -> {"type", "url"}
      * ``type_text(text="...")``                         -> {"type", "text"}
      * ``press_key(key="...")``                          -> {"type", "key"}
      * ``scroll(...)`` with optional ``amount=``/``direction=`` keyword
        arguments -> {"type", "amount", "direction"} (defaults: 2, "down")

    Returns an empty dict for anything unrecognized.
    """
    # 1. Coordinate actions: name(x=..., y=...)
    coord_match = re.match(r"(\w+)\s*\(\s*x\s*=\s*([0-9.]+)\s*,\s*y\s*=\s*([0-9.]+)\s*\)", action_str)
    if coord_match:
        return {
            "type": coord_match.group(1),
            "x": float(coord_match.group(2)),
            "y": float(coord_match.group(3))
        }

    # 2. Open URL: open_url(url="...")
    url_match = re.match(r"open_url\s*\(\s*url\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
    if url_match:
        return {"type": "open_url", "url": url_match.group(1)}

    # 3. Type text: type_text(text="...")
    text_match = re.match(r"type_text\s*\(\s*text\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
    if text_match:
        return {"type": "type_text", "text": text_match.group(1)}

    # 4. Press key: press_key(key="...")
    key_match = re.match(r"press_key\s*\(\s*key\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
    if key_match:
        return {"type": "press_key", "key": key_match.group(1)}

    # 5. Scroll. BUG FIX: honor amount=/direction= keyword arguments instead
    # of always returning the hard-coded defaults (the previous version
    # discarded the model's requested scroll parameters).
    if "scroll" in action_str:
        amount_match = re.search(r"amount\s*=\s*(\d+)", action_str)
        direction_match = re.search(r"direction\s*=\s*[\"'](up|down)[\"']", action_str)
        return {
            "type": "scroll",
            "amount": int(amount_match.group(1)) if amount_match else 2,
            "direction": direction_match.group(1) if direction_match else "down",
        }

    return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
+ # -----------------------------------------------------------------------------
288
+ # MAIN LOOP
289
+ # -----------------------------------------------------------------------------
 
 
 
 
290
 
291
@spaces.GPU(duration=120)
def agent_step(task_instruction: str, history: list, sandbox_state: dict):
    """Run one perceive -> think -> act step of the browser agent.

    Ensures a SeleniumSandbox exists for this session (keyed by a UUID kept
    in the Gradio state dict and resolved through the module-level
    SANDBOX_REGISTRY, because live browser objects cannot be stored in
    Gradio state directly), screenshots the page, asks the model for the
    next action, executes it, and returns the (possibly annotated)
    screenshot plus the updated history and state.
    """
    # Assign a session id on first use. BUG FIX: the previous version also
    # constructed a throwaway SeleniumSandbox in this branch that was never
    # registered or cleaned up -- a leaked headless Chrome process per new
    # session. The registry below is the single owner of sandbox instances.
    if 'uuid' not in sandbox_state:
        sandbox_state['uuid'] = str(uuid.uuid4())

    sandbox_id = sandbox_state['uuid']
    if sandbox_id not in SANDBOX_REGISTRY:
        SANDBOX_REGISTRY[sandbox_id] = SeleniumSandbox(WIDTH, HEIGHT)

    sandbox = SANDBOX_REGISTRY[sandbox_id]

    # 1. Get Screenshot
    screenshot = sandbox.get_screenshot()

    # 2. Construct Prompt (only the most recent log entry is fed back as
    # "Previous Actions" context to keep the prompt short)
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": screenshot},
                {"type": "text", "text": f"Instruction: {task_instruction}\nPrevious Actions: {history[-1] if history else 'None'}"}
            ]
        }
    ]

    # 3. Model Inference
    response = model.generate(messages)

    # 4. Parse Action
    action_code = parse_code_block(response)
    action_data = parse_action_string(action_code)

    log_entry = f"Step: {len(history)+1}\nModel Thought: {response}\nAction: {action_code}"

    # 5. Execute Action
    execution_result = "No valid action found"
    if action_data:
        execution_result = sandbox.execute_action(action_data)

        # Draw a red marker on the screenshot for coordinate actions so the
        # UI shows where the agent clicked.
        if 'x' in action_data:
            draw = ImageDraw.Draw(screenshot)
            x_px = action_data['x'] * WIDTH
            y_px = action_data['y'] * HEIGHT
            r = 10
            draw.ellipse((x_px-r, y_px-r, x_px+r, y_px+r), outline="red", width=3)

    log_entry += f"\nResult: {execution_result}"
    history.append(log_entry)

    # Return updated screenshot and history
    return screenshot, history, sandbox_state
 
 
 
 
 
 
 
356
 
357
# Global registry mapping session UUID -> live SeleniumSandbox instance
# (browser objects cannot be held in Gradio session state directly).
SANDBOX_REGISTRY = {}

def cleanup_sandbox(sandbox_state):
    """Tear down this session's sandbox, if any, and reset history/state."""
    session_id = sandbox_state.get('uuid')
    if session_id:
        live = SANDBOX_REGISTRY.pop(session_id, None)
        if live is not None:
            live.cleanup()
    # Fresh (history, state) pair for the UI.
    return [], {}
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
  # -----------------------------------------------------------------------------
368
+ # GRADIO UI
369
  # -----------------------------------------------------------------------------
370
 
371
def run_task_loop(task, history, state):
    """Generator driving the agent for up to ``max_steps`` steps.

    Yields ``(screenshot, logs_text, state)`` after every step so the
    Gradio UI can stream progress. Stops early when the last log entry
    signals completion, or on the first exception from agent_step.
    """
    max_steps = 10

    for i in range(max_steps):
        try:
            # Run one step
            screenshot, new_history, new_state = agent_step(task, history, state)
            history = new_history
            state = new_state

            # Yield updates to UI. BUG FIX: the divider string must be the
            # (parenthesized) separator passed to .join(); the previous
            # version's `"\n\n" + "-"*40 + "\n\n".join(history)` prefixed a
            # single divider instead of separating the entries, due to
            # operator precedence.
            logs_text = ("\n\n" + "-" * 40 + "\n\n").join(history)
            yield screenshot, logs_text, state

            # Check for termination (simplistic keyword match on last entry)
            if "Done" in history[-1] or "finished" in history[-1].lower():
                break

            time.sleep(1)  # Pause for visual effect

        except Exception as e:
            error_msg = f"Error in loop: {e}"
            history.append(error_msg)
            yield None, "\n".join(history), state
            break
 
 
 
 
397
 
398
# UI Layout
custom_css = """
#view_img { height: 600px; object-fit: contain; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Per-session state: sandbox descriptor dict + list of log entries.
    state = gr.State({})
    history = gr.State([])

    gr.Markdown("# 🤖 Fara CUA - Chrome Agent")

    with gr.Row():
        # Left column: task entry and controls.
        with gr.Column(scale=1):
            task_input = gr.Textbox(
                label="Task Instruction",
                value="Go to google.com and search for 'SpaceX'",
            )
            run_btn = gr.Button("Run Agent", variant="primary")
            clear_btn = gr.Button("Reset / Clear")

        # Right column: live screenshot stream.
        with gr.Column(scale=2):
            browser_view = gr.Image(
                label="Live Browser View",
                elem_id="view_img",
                interactive=False,
            )

    logs_output = gr.Textbox(label="Agent Logs", lines=15, interactive=False)

    # Running streams (screenshot, logs, state) updates from the generator.
    run_btn.click(
        fn=run_task_loop,
        inputs=[task_input, history, state],
        outputs=[browser_view, logs_output, state],
    )

    # Reset tears down the browser session, then blanks the view and logs.
    clear_btn.click(
        fn=cleanup_sandbox,
        inputs=[state],
        outputs=[history, state],
    ).then(
        lambda: (None, ""),
        outputs=[browser_view, logs_output],
    )

if __name__ == "__main__":
    demo.launch()