import os
import re
import time
import shutil
import uuid
import json
import tempfile
from io import BytesIO
import threading

import gradio as gr
import torch
import spaces
from PIL import Image, ImageDraw

# Transformers imports
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
)
from qwen_vl_utils import process_vision_info

# Selenium Imports
from selenium import webdriver
from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.actions.action_builder import ActionBuilder
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from webdriver_manager.chrome import ChromeDriverManager

# -----------------------------------------------------------------------------
# CONSTANTS & CONFIG
# -----------------------------------------------------------------------------
MODEL_ID = "microsoft/Fara-7B"
# Use the Qwen fallback if Fara isn't directly accessible in your environment
FALLBACK_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
WIDTH = 1024
HEIGHT = 768

TMP_DIR = "./tmp"
os.makedirs(TMP_DIR, exist_ok=True)

# Updated System Prompt to match the JSON tool_call format the model prefers
OS_SYSTEM_PROMPT = """You are a helpful GUI agent controlling a Chrome browser.
You will be given a screenshot of the current page and a high-level task.
You need to generate the next action to move towards completing the task.
The browser resolution is 1024x768.

Output your action in the following XML format containing JSON:
{"name": "Browser", "arguments": { ... }}

Supported Actions (in 'arguments'):
1. Click: {"action": "click", "coordinate": [x, y]} (where x and y are integer coordinates based on a 1000x1000 normalized grid)
2. Type: {"action": "type_text", "text": "something", "coordinate": [x, y], "press_enter": true} (Coordinate is optional but recommended to focus the input field first)
3. Scroll: {"action": "scroll", "direction": "down"}
4. Navigate: {"action": "navigate", "url": "https://..."}

Example:
{"name": "Browser", "arguments": {"action": "type_text", "coordinate": [500, 280], "text": "hugging face models", "press_enter": true}}
"""


# -----------------------------------------------------------------------------
# MODEL WRAPPER
# -----------------------------------------------------------------------------
class ModelWrapper:
    """Load a Qwen2.5-VL-compatible checkpoint and expose a simple generate() API.

    Falls back to FALLBACK_MODEL_ID if the primary checkpoint cannot be loaded
    (e.g. gated repo, missing weights).
    """

    def __init__(self, model_id: str, to_device: str = "cuda"):
        print(f"Loading model: {model_id} on {to_device}...")
        self.device = to_device
        try:
            self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
                device_map="auto" if to_device == "cuda" else None,
            )
        except Exception as e:
            # Any load failure (auth, network, missing repo) triggers the fallback.
            print(f"Primary model load failed ({e}). Loading fallback: {FALLBACK_MODEL_ID}")
            self.processor = AutoProcessor.from_pretrained(FALLBACK_MODEL_ID, trust_remote_code=True)
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                FALLBACK_MODEL_ID,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
                device_map="auto" if to_device == "cuda" else None,
            )

        if to_device == "cpu":
            self.model.to("cpu")
        self.model.eval()
        print("Model loaded successfully.")

    def generate(self, messages: list[dict], max_new_tokens: int = 512) -> str:
        """Run one chat-completion turn over the given multimodal messages.

        Args:
            messages: Qwen-style chat messages (may contain image entries).
            max_new_tokens: generation budget.

        Returns:
            The decoded assistant text (prompt tokens stripped).
        """
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Drop the echoed prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text


# Initialize Global Model (heavy: downloads/loads weights at import time).
model = ModelWrapper(MODEL_ID, DEVICE)


# -----------------------------------------------------------------------------
# SELENIUM SANDBOX
# -----------------------------------------------------------------------------
def get_system_chrome_path() -> str | None:
    """Return the first system Chrome/Chromium binary found, or None."""
    paths = ["/usr/bin/chromium", "/usr/bin/chromium-browser", "/usr/bin/google-chrome"]
    for p in paths:
        if os.path.exists(p):
            return p
    return None


class SeleniumSandbox:
    """Headless Chrome session with a fixed viewport the agent can drive."""

    def __init__(self, width: int = 1024, height: int = 768):
        self.width = width
        self.height = height
        # Isolated profile dir so concurrent sandboxes don't collide.
        self.tmp_dir = tempfile.mkdtemp(prefix="chrome_sandbox_")

        chrome_opts = ChromeOptions()
        binary_path = get_system_chrome_path()
        if binary_path:
            chrome_opts.binary_location = binary_path
        chrome_opts.add_argument("--headless=new")
        chrome_opts.add_argument(f"--user-data-dir={self.tmp_dir}")
        chrome_opts.add_argument(f"--window-size={width},{height}")
        chrome_opts.add_argument("--no-sandbox")
        chrome_opts.add_argument("--disable-dev-shm-usage")
        chrome_opts.add_argument("--disable-gpu")

        try:
            # Prefer a system chromedriver; otherwise download a matching one.
            system_driver_path = "/usr/bin/chromedriver"
            if os.path.exists(system_driver_path):
                service = ChromeService(executable_path=system_driver_path)
            else:
                service = ChromeService(ChromeDriverManager().install())

            self.driver = webdriver.Chrome(service=service, options=chrome_opts)
            self.driver.set_window_size(width, height)
            # Start blank
            self.driver.get("about:blank")
            print("Selenium started.")
        except Exception as e:
            print(f"Selenium init failed: {e}")
            shutil.rmtree(self.tmp_dir, ignore_errors=True)
            raise e

    def get_screenshot(self) -> Image.Image:
        """Capture the current viewport as a PIL image."""
        return Image.open(BytesIO(self.driver.get_screenshot_as_png()))

    def execute_action(self, action_data: dict) -> str:
        """Execute parsed JSON action on the browser.

        Expects ``action_data`` shaped like
        ``{"name": "Browser", "arguments": {"action": ..., ...}}``.
        Returns a short human-readable status string.
        """
        args = action_data.get("arguments", {})
        action_type = args.get("action")

        try:
            actions = ActionChains(self.driver)

            # 1. Handle Coordinate Movement (Common to click/type)
            if "coordinate" in args:
                coords = args["coordinate"]
                # Assuming Fara uses 1000x1000 normalization standard
                x_px = int(coords[0] / 1000 * self.width)
                y_px = int(coords[1] / 1000 * self.height)

                # BUGFIX: Selenium 4's move_to_element_with_offset() measures the
                # offset from the element CENTER, so the old body+(x,y) approach
                # clicked ~half a viewport away from the intended pixel.
                # ActionBuilder.move_to_location() uses absolute viewport
                # coordinates, which matches the model's pixel output directly.
                builder = ActionBuilder(self.driver)
                builder.pointer_action.move_to_location(x_px, y_px)
                builder.pointer_action.click()  # Focus the element
                builder.perform()

            # 2. Handle Specific Actions
            if action_type == "navigate":
                url = args.get("url")
                if url:
                    if not url.startswith("http"):
                        url = "https://" + url
                    self.driver.get(url)
                    time.sleep(2)
                    return f"Navigated to {url}"

            elif action_type == "type_text":
                text = args.get("text", "")
                actions.send_keys(text)
                if args.get("press_enter", False):
                    actions.send_keys(Keys.ENTER)
                actions.perform()
                return f"Typed '{text}'"

            elif action_type == "click":
                # Click is handled in coordinate block above, just return status
                return f"Clicked at {args.get('coordinate')}"

            elif action_type == "scroll":
                direction = args.get("direction", "down")
                scroll_amount = 300 if direction == "down" else -300
                self.driver.execute_script(f"window.scrollBy(0, {scroll_amount});")
                return f"Scrolled {direction}"

            return f"Executed {action_type}"

        except Exception as e:
            print(f"Execution Error: {e}")
            return f"Action failed: {e}"

    def cleanup(self):
        """Quit the browser and remove the temporary profile directory."""
        try:
            self.driver.quit()
        except Exception:
            # Best-effort: the driver may already be dead.
            pass
        shutil.rmtree(self.tmp_dir, ignore_errors=True)


# -----------------------------------------------------------------------------
# PARSER
# -----------------------------------------------------------------------------
def parse_model_response(response: str) -> dict | None:
    """Extract the first complete JSON object from the model's response.

    Returns the parsed dict, or None if no valid JSON object is found.

    BUGFIX: the previous non-greedy regex ``{.*?}`` stopped at the FIRST
    closing brace, which truncated the nested ``{"name": ..., "arguments":
    {...}}`` payload and made json.loads fail on every action. raw_decode
    consumes a full balanced JSON value instead.
    """
    decoder = json.JSONDecoder()
    start = response.find("{")
    while start != -1:
        try:
            data, _ = decoder.raw_decode(response[start:])
            if isinstance(data, dict):
                return data
        except json.JSONDecodeError:
            pass
        # Try the next opening brace (skips stray '{' in surrounding prose).
        start = response.find("{", start + 1)
    print("Failed to decode JSON from tool_call")
    return None


# -----------------------------------------------------------------------------
# AGENT LOOP
# -----------------------------------------------------------------------------
# Global registry to persist sessions in Gradio. Guarded by a lock because
# Gradio may service concurrent sessions on different threads.
SANDBOX_REGISTRY = {}
_REGISTRY_LOCK = threading.Lock()


@spaces.GPU(duration=120)
def agent_step(task_instruction: str, history: list, sandbox_state: dict):
    """Run one observe → think → act cycle.

    Returns (annotated screenshot, updated history, updated sandbox_state).
    """
    # Retrieve or create sandbox
    if 'uuid' not in sandbox_state:
        sandbox_state['uuid'] = str(uuid.uuid4())
    sid = sandbox_state['uuid']

    with _REGISTRY_LOCK:
        if sid not in SANDBOX_REGISTRY:
            SANDBOX_REGISTRY[sid] = SeleniumSandbox(WIDTH, HEIGHT)
        sandbox = SANDBOX_REGISTRY[sid]

    # 1. Capture State
    screenshot = sandbox.get_screenshot()

    # 2. Build Messages
    # Fara works best when seeing the history of images, but for memory efficiency
    # in this demo we will just send the current screenshot + text history.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": screenshot},
                {"type": "text", "text": f"Task: {task_instruction}\nPrevious Actions Log:\n" + "\n".join(history[-3:])}
            ]
        }
    ]

    # 3. Inference
    response = model.generate(messages)

    # 4. Parse & Execute
    action_data = parse_model_response(response)
    log_entry = f"Thought: {response}\n"

    if action_data:
        result = sandbox.execute_action(action_data)
        log_entry += f"Action: {action_data.get('arguments', {}).get('action')}\nResult: {result}"

        # Visualize click on screenshot for UI
        args = action_data.get("arguments", {})
        if "coordinate" in args:
            draw = ImageDraw.Draw(screenshot)
            coords = args["coordinate"]
            # Map 1000x1000 back to image size
            x = int(coords[0] / 1000 * WIDTH)
            y = int(coords[1] / 1000 * HEIGHT)
            draw.ellipse((x - 10, y - 10, x + 10, y + 10), outline="red", width=5)
    else:
        log_entry += "Action: Parsing Failed or No Action"

    history.append(log_entry)
    return screenshot, history, sandbox_state


def cleanup_sandbox(sandbox_state: dict):
    """Tear down this session's browser and clear UI state."""
    sid = sandbox_state.get('uuid')
    with _REGISTRY_LOCK:
        if sid and sid in SANDBOX_REGISTRY:
            SANDBOX_REGISTRY[sid].cleanup()
            del SANDBOX_REGISTRY[sid]
    return [], {}


# -----------------------------------------------------------------------------
# GRADIO UI
# -----------------------------------------------------------------------------
def run_loop(task, history, state):
    """Drive the agent for up to MAX_STEPS, yielding UI updates each step."""
    MAX_STEPS = 10
    for i in range(MAX_STEPS):
        try:
            img, new_hist, new_state = agent_step(task, history, state)
            history = new_hist
            # Combine history into a readable log.
            # BUGFIX: the separator is now placed BETWEEN entries; the previous
            # precedence ("\n" + "="*40 + "\n".join(...)) only prepended it once.
            log_text = ("\n" + "=" * 40 + "\n").join(history)
            yield img, log_text, state
            time.sleep(1)  # Visual pause
        except Exception as e:
            history.append(f"Critical Error: {e}")
            yield None, "\n".join(history), state
            break


custom_css = """
.browser-img { height: 600px; object-fit: contain; border: 2px solid #333; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    state = gr.State({})
    history = gr.State([])

    gr.Markdown("# 🌐 Fara CUA - Chrome Agent")
    gr.Markdown("Agent that uses **Microsoft Fara-7B** (Vision) to control a headless Chrome browser.")

    with gr.Row():
        with gr.Column(scale=1):
            task_input = gr.Textbox(
                label="Task",
                value="Go to google.com and search for 'Hugging Face models'",
                lines=2
            )
            with gr.Row():
                run_btn = gr.Button("▶ Run Agent", variant="primary")
                reset_btn = gr.Button("⏹ Reset")
            gr.Examples([
                "Go to google.com and search for 'Hugging Face models'",
                "Navigate to wikipedia.org, type 'Artificial Intelligence' and press enter",
                "Go to bing.com and search for 'SpaceX launch'"
            ], inputs=task_input)
        with gr.Column(scale=2):
            browser_view = gr.Image(
                label="Live Browser View",
                interactive=False,
                elem_classes="browser-img",
                type="pil"
            )
            logs_out = gr.Textbox(label="Execution Logs", lines=10, interactive=False)

    run_btn.click(
        fn=run_loop,
        inputs=[task_input, history, state],
        outputs=[browser_view, logs_out, state]
    )
    reset_btn.click(
        fn=cleanup_sandbox,
        inputs=[state],
        outputs=[history, state]
    ).then(
        lambda: (None, ""),
        outputs=[browser_view, logs_out]
    )

if __name__ == "__main__":
    demo.launch(share=True)