Spaces: Running on Zero
| import os | |
| import re | |
| import time | |
| import shutil | |
| import uuid | |
| import json | |
| import tempfile | |
| from io import BytesIO | |
| import threading | |
| import gradio as gr | |
| import torch | |
| import spaces | |
| from PIL import Image, ImageDraw | |
| # Transformers imports | |
| from transformers import ( | |
| Qwen2_5_VLForConditionalGeneration, | |
| AutoProcessor, | |
| ) | |
| from qwen_vl_utils import process_vision_info | |
| # Selenium Imports | |
| from selenium import webdriver | |
| from selenium.webdriver.chrome.service import Service as ChromeService | |
| from selenium.webdriver.chrome.options import Options as ChromeOptions | |
| from selenium.webdriver.common.action_chains import ActionChains | |
| from selenium.webdriver.common.by import By | |
| from selenium.webdriver.common.keys import Keys | |
| from webdriver_manager.chrome import ChromeDriverManager | |
# -----------------------------------------------------------------------------
# CONSTANTS & CONFIG
# -----------------------------------------------------------------------------
MODEL_ID = "microsoft/Fara-7B"
# Use the Qwen fallback if Fara isn't directly accessible in your environment
FALLBACK_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Virtual browser viewport size; must match the resolution advertised to the
# model in OS_SYSTEM_PROMPT so coordinate mapping stays consistent.
WIDTH = 1024
HEIGHT = 768
# Scratch directory for temporary artifacts.
TMP_DIR = "./tmp"
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(TMP_DIR, exist_ok=True)
# Updated System Prompt to match the JSON tool_call format the model prefers.
# NOTE: the model is told to emit coordinates on a 1000x1000 normalized grid;
# they are scaled back to the real 1024x768 viewport before execution.
OS_SYSTEM_PROMPT = """You are a helpful GUI agent controlling a Chrome browser.
You will be given a screenshot of the current page and a high-level task.
You need to generate the next action to move towards completing the task.
The browser resolution is 1024x768.
Output your action in the following XML format containing JSON:
<tool_call>
{"name": "Browser", "arguments": { ... }}
</tool_call>
Supported Actions (in 'arguments'):
1. Click: {"action": "click", "coordinate": [x, y]}
(where x and y are integer coordinates based on a 1000x1000 normalized grid)
2. Type: {"action": "type_text", "text": "something", "coordinate": [x, y], "press_enter": true}
(Coordinate is optional but recommended to focus the input field first)
3. Scroll: {"action": "scroll", "direction": "down"}
4. Navigate: {"action": "navigate", "url": "https://..."}
Example:
<tool_call>
{"name": "Browser", "arguments": {"action": "type_text", "coordinate": [500, 280], "text": "hugging face models", "press_enter": true}}
</tool_call>
"""
| # ----------------------------------------------------------------------------- | |
| # MODEL WRAPPER | |
| # ----------------------------------------------------------------------------- | |
class ModelWrapper:
    """Thin wrapper around a Qwen2.5-VL-style vision-language model.

    Tries to load the requested checkpoint and falls back to
    FALLBACK_MODEL_ID when the primary one is unavailable (e.g. gated
    or not accessible from this environment).
    """

    def __init__(self, model_id: str, to_device: str = "cuda"):
        self.device = to_device
        try:
            self.processor, self.model = self._load(model_id, to_device)
        except Exception as e:
            # Best-effort fallback: same architecture, publicly available.
            print(f"Primary model load failed ({e}). Loading fallback: {FALLBACK_MODEL_ID}")
            self.processor, self.model = self._load(FALLBACK_MODEL_ID, to_device)
        if to_device == "cpu":
            self.model.to("cpu")
        self.model.eval()
        print("Model loaded successfully.")

    @staticmethod
    def _load(model_id: str, to_device: str):
        """Load (processor, model) for model_id; shared by primary and fallback paths."""
        print(f"Loading model: {model_id} on {to_device}...")
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            trust_remote_code=True,
            # bf16 on GPU for memory; full fp32 on CPU where bf16 is slow/unsupported.
            torch_dtype=torch.bfloat16 if to_device == "cuda" else torch.float32,
            device_map="auto" if to_device == "cuda" else None,
        )
        return processor, model

    def generate(self, messages: list[dict], max_new_tokens: int = 512) -> str:
        """Run one chat turn and return only the newly generated text.

        messages follows the Qwen-VL chat format (role + mixed image/text
        content parts); images are extracted via process_vision_info.
        """
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Strip the prompt tokens so only the model's continuation is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
# Initialize Global Model once at import time.
# Heavyweight: downloads and loads a ~7B vision-language model onto DEVICE.
model = ModelWrapper(MODEL_ID, DEVICE)
| # ----------------------------------------------------------------------------- | |
| # SELENIUM SANDBOX | |
| # ----------------------------------------------------------------------------- | |
def get_system_chrome_path():
    """Return the first Chrome/Chromium binary found on disk, or None."""
    candidates = (
        "/usr/bin/chromium",
        "/usr/bin/chromium-browser",
        "/usr/bin/google-chrome",
    )
    return next((path for path in candidates if os.path.exists(path)), None)
class SeleniumSandbox:
    """Headless Chrome session that the agent observes and acts on.

    Coordinates arriving from the model are on a 1000x1000 normalized grid
    and are scaled to the real viewport (width x height) before dispatch.
    """

    def __init__(self, width=1024, height=768):
        self.width = width
        self.height = height
        # Each sandbox gets its own throwaway Chrome profile directory.
        self.tmp_dir = tempfile.mkdtemp(prefix="chrome_sandbox_")
        chrome_opts = ChromeOptions()
        binary_path = get_system_chrome_path()
        if binary_path:
            chrome_opts.binary_location = binary_path
        chrome_opts.add_argument("--headless=new")
        chrome_opts.add_argument(f"--user-data-dir={self.tmp_dir}")
        chrome_opts.add_argument(f"--window-size={width},{height}")
        # --no-sandbox / --disable-dev-shm-usage are required in most
        # containerized environments (e.g. HF Spaces).
        chrome_opts.add_argument("--no-sandbox")
        chrome_opts.add_argument("--disable-dev-shm-usage")
        chrome_opts.add_argument("--disable-gpu")
        try:
            # Prefer the distro-provided chromedriver; otherwise download one.
            system_driver_path = "/usr/bin/chromedriver"
            if os.path.exists(system_driver_path):
                service = ChromeService(executable_path=system_driver_path)
            else:
                service = ChromeService(ChromeDriverManager().install())
            self.driver = webdriver.Chrome(service=service, options=chrome_opts)
            self.driver.set_window_size(width, height)
            # Start blank
            self.driver.get("about:blank")
            print("Selenium started.")
        except Exception as e:
            print(f"Selenium init failed: {e}")
            # Don't leak the profile directory when startup fails.
            shutil.rmtree(self.tmp_dir, ignore_errors=True)
            raise e

    def get_screenshot(self):
        """Return the current page rendered as a PIL Image."""
        return Image.open(BytesIO(self.driver.get_screenshot_as_png()))

    def execute_action(self, action_data: dict):
        """Execute a parsed JSON action on the browser.

        action_data is the dict decoded from the model's <tool_call> payload:
        {"name": "Browser", "arguments": {"action": ..., ...}}.
        Returns a short human-readable status string (surfaced in the UI log);
        failures are caught and reported rather than raised.
        """
        args = action_data.get("arguments", {})
        action_type = args.get("action")
        try:
            actions = ActionChains(self.driver)
            body = self.driver.find_element(By.TAG_NAME, "body")
            # 1. Handle coordinate movement (common to click/type): scale the
            # model's 1000x1000 grid to viewport pixels and click to focus.
            if "coordinate" in args:
                coords = args["coordinate"]
                x_px = int(coords[0] / 1000 * self.width)
                y_px = int(coords[1] / 1000 * self.height)
                # NOTE(review): Selenium 4 measures element offsets from the
                # element CENTER, so body + (0, 0) lands mid-page, not at the
                # top-left corner — confirm the intended origin here.
                actions.move_to_element_with_offset(body, 0, 0)
                actions.move_by_offset(x_px, y_px)
                actions.click()  # Focus the element
                actions.perform()
                # Reset actions queue
                actions = ActionChains(self.driver)
            # 2. Handle specific actions
            if action_type == "navigate":
                url = args.get("url")
                if url:
                    if not url.startswith("http"):
                        url = "https://" + url
                    self.driver.get(url)
                    time.sleep(2)
                    return f"Navigated to {url}"
            elif action_type == "type_text":
                text = args.get("text", "")
                actions.send_keys(text)
                if args.get("press_enter", False):
                    actions.send_keys(Keys.ENTER)
                actions.perform()
                return f"Typed '{text}'"
            elif action_type == "click":
                # Click is handled in coordinate block above, just return status
                return f"Clicked at {args.get('coordinate')}"
            elif action_type == "scroll":
                direction = args.get("direction", "down")
                scroll_amount = 300 if direction == "down" else -300
                self.driver.execute_script(f"window.scrollBy(0, {scroll_amount});")
                return f"Scrolled {direction}"
            return f"Executed {action_type}"
        except Exception as e:
            print(f"Execution Error: {e}")
            return f"Action failed: {e}"

    def cleanup(self):
        """Quit the browser and remove the temporary profile directory."""
        try:
            self.driver.quit()
        except Exception:
            # Narrowed from a bare except: never swallow KeyboardInterrupt /
            # SystemExit; a dead driver should still not block teardown.
            pass
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
| # ----------------------------------------------------------------------------- | |
| # PARSER | |
| # ----------------------------------------------------------------------------- | |
| def parse_model_response(response: str) -> dict: | |
| """ | |
| Parses <tool_call> JSON content </tool_call> | |
| Returns a dictionary or None | |
| """ | |
| # Regex to extract JSON inside tool_call tags | |
| pattern = r"<tool_call>\s*({.*?})\s*</tool_call>" | |
| match = re.search(pattern, response, re.DOTALL) | |
| if match: | |
| try: | |
| json_str = match.group(1) | |
| data = json.loads(json_str) | |
| return data | |
| except json.JSONDecodeError: | |
| print("Failed to decode JSON from tool_call") | |
| return None | |
| return None | |
| # ----------------------------------------------------------------------------- | |
| # AGENT LOOP | |
| # ----------------------------------------------------------------------------- | |
# Global registry to persist sessions in Gradio.
# Maps session UUID (kept in a gr.State dict) -> live SeleniumSandbox, so the
# browser survives across individual Gradio callback invocations.
SANDBOX_REGISTRY = {}
def agent_step(task_instruction: str, history: list, sandbox_state: dict):
    """Run one perceive-think-act cycle of the browser agent.

    Screenshots the session's browser, asks the model for the next action,
    executes it, and appends a log entry to history (mutated in place).
    Returns (annotated screenshot, history, sandbox_state).
    """
    # Lazily create a per-session sandbox keyed by a UUID kept in gr.State.
    sid = sandbox_state.setdefault('uuid', str(uuid.uuid4()))
    sandbox = SANDBOX_REGISTRY.get(sid)
    if sandbox is None:
        sandbox = SANDBOX_REGISTRY[sid] = SeleniumSandbox(WIDTH, HEIGHT)

    # 1. Capture the current browser state.
    screenshot = sandbox.get_screenshot()

    # 2. Build the chat messages. Fara works best when seeing the history of
    # images, but for memory efficiency this demo sends only the current
    # screenshot plus the last few text log entries.
    recent_log = "\n".join(history[-3:])
    messages = [
        {"role": "system", "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": screenshot},
                {"type": "text", "text": f"Task: {task_instruction}\nPrevious Actions Log:\n" + recent_log},
            ],
        },
    ]

    # 3. Inference.
    response = model.generate(messages)

    # 4. Parse the tool call and execute it.
    action_data = parse_model_response(response)
    log_entry = f"Thought: {response}\n"
    if action_data is None:
        log_entry += "Action: Parsing Failed or No Action"
    else:
        result = sandbox.execute_action(action_data)
        args = action_data.get("arguments", {})
        log_entry += f"Action: {args.get('action')}\nResult: {result}"
        # Visualize the click target on the screenshot for the UI.
        if "coordinate" in args:
            coords = args["coordinate"]
            # Map the 1000x1000 grid back to image pixels.
            cx = int(coords[0] / 1000 * WIDTH)
            cy = int(coords[1] / 1000 * HEIGHT)
            ImageDraw.Draw(screenshot).ellipse(
                (cx - 10, cy - 10, cx + 10, cy + 10), outline="red", width=5
            )

    history.append(log_entry)
    return screenshot, history, sandbox_state
def cleanup_sandbox(sandbox_state):
    """Tear down this session's browser (if any) and reset the Gradio state.

    Returns fresh (history, state) values for the UI.
    """
    sid = sandbox_state.get('uuid')
    sandbox = SANDBOX_REGISTRY.pop(sid, None) if sid else None
    if sandbox is not None:
        sandbox.cleanup()
    return [], {}
| # ----------------------------------------------------------------------------- | |
| # GRADIO UI | |
| # ----------------------------------------------------------------------------- | |
def run_loop(task, history, state):
    """Drive the agent for up to MAX_STEPS steps, streaming UI updates.

    Generator yielding (screenshot, log_text, state) after every step so
    Gradio can refresh the browser view live; stops early on an error.
    """
    MAX_STEPS = 10
    for i in range(MAX_STEPS):
        try:
            img, new_hist, new_state = agent_step(task, history, state)
            history = new_hist
            # Combine history into a readable log, one ruled-off section per
            # step. (Fixed: the separator must be the join glue; the original
            # prepended the ruler once and ran all entries together.)
            log_text = ("\n" + "=" * 40 + "\n").join(history)
            yield img, log_text, state
            time.sleep(1)  # Visual pause
        except Exception as e:
            history.append(f"Critical Error: {e}")
            yield None, "\n".join(history), state
            break
# CSS: pin the browser screenshot to a fixed height so the layout stays stable
# while screenshots stream in.
custom_css = """
.browser-img { height: 600px; object-fit: contain; border: 2px solid #333; }
"""
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Per-session state: `state` holds the sandbox UUID dict, `history` the
    # accumulated action-log entries.
    state = gr.State({})
    history = gr.State([])
    gr.Markdown("# 🌐 Fara CUA - Chrome Agent")
    gr.Markdown("Agent that uses **Microsoft Fara-7B** (Vision) to control a headless Chrome browser.")
    with gr.Row():
        with gr.Column(scale=1):
            task_input = gr.Textbox(
                label="Task",
                value="Go to google.com and search for 'Hugging Face models'",
                lines=2
            )
            with gr.Row():
                run_btn = gr.Button("▶ Run Agent", variant="primary")
                reset_btn = gr.Button("⏹ Reset")
            gr.Examples([
                "Go to google.com and search for 'Hugging Face models'",
                "Navigate to wikipedia.org, type 'Artificial Intelligence' and press enter",
                "Go to bing.com and search for 'SpaceX launch'"
            ], inputs=task_input)
        with gr.Column(scale=2):
            browser_view = gr.Image(
                label="Live Browser View",
                interactive=False,
                elem_classes="browser-img",
                type="pil"
            )
            logs_out = gr.Textbox(label="Execution Logs", lines=10, interactive=False)
    # Run: run_loop is a generator, so Gradio streams each yielded
    # (screenshot, log, state) triple to the outputs.
    run_btn.click(
        fn=run_loop,
        inputs=[task_input, history, state],
        outputs=[browser_view, logs_out, state]
    )
    # Reset: tear down the session's browser first, then clear the view/logs.
    reset_btn.click(
        fn=cleanup_sandbox,
        inputs=[state],
        outputs=[history, state]
    ).then(
        lambda: (None, ""),
        outputs=[browser_view, logs_out]
    )
if __name__ == "__main__":
    # share=True exposes a public Gradio link in addition to the local server.
    demo.launch(share=True)