Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import re | |
| import time | |
| import shutil | |
| import uuid | |
| import tempfile | |
| import unicodedata | |
| from io import BytesIO | |
| from typing import Tuple, Optional, List, Dict, Any | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import spaces | |
| from PIL import Image, ImageDraw, ImageFont | |
| # Transformers imports | |
| from transformers import ( | |
| Qwen2_5_VLForConditionalGeneration, | |
| AutoProcessor, | |
| ) | |
| from qwen_vl_utils import process_vision_info | |
| # Selenium Imports | |
| from selenium import webdriver | |
| from selenium.webdriver.chrome.service import Service as ChromeService | |
| from selenium.webdriver.chrome.options import Options as ChromeOptions | |
| from selenium.webdriver.common.action_chains import ActionChains | |
| from selenium.webdriver.common.by import By | |
| from selenium.webdriver.common.keys import Keys | |
| from webdriver_manager.chrome import ChromeDriverManager | |
| # ----------------------------------------------------------------------------- | |
| # CONSTANTS & CONFIG | |
| # ----------------------------------------------------------------------------- | |
# Hugging Face repo id of the checkpoint to load.
MODEL_ID = "microsoft/Fara-7B"  # Or your specific Fara model repo
# Prefer the GPU when torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Browser window size in pixels; agent_step() passes these to SeleniumSandbox.
WIDTH = 1024
HEIGHT = 768
# Scratch directory, created at import time. exist_ok avoids the
# check-then-create race when two processes import this module at once.
TMP_DIR = "./tmp"
os.makedirs(TMP_DIR, exist_ok=True)
# System prompt adapted for Fara/GUI agents.
# It enumerates the seven action forms the model may emit and asks for the
# chosen action to be wrapped in <code> ... </code> tags; parse_code_block()
# extracts that wrapper and parse_action_string() turns the call into a dict.
# The example coordinates are fractions of the screen (0-1), which is how
# SeleniumSandbox.execute_action() interprets x / y.
OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
You need to generate the next action to complete the task.
Supported actions:
1. `click(x=0.5, y=0.5)`: Click at the specific location.
2. `right_click(x=0.5, y=0.5)`: Right click at the specific location.
3. `double_click(x=0.5, y=0.5)`: Double click at the specific location.
4. `type_text(text="hello")`: Type the text.
5. `scroll(amount=2, direction="down")`: Scroll the page.
6. `press_key(key="enter")`: Press a specific key.
7. `open_url(url="https://google.com")`: Open a specific URL.
Output format:
Please wrap the action code in <code> </code> tags.
Example:
<code>
click(x=0.23, y=0.45)
</code>
"""
| # ----------------------------------------------------------------------------- | |
| # MODEL WRAPPER (Replacing smolagents) | |
| # ----------------------------------------------------------------------------- | |
class FaraModelWrapper:
    """Load a Qwen2.5-VL style checkpoint and run single-turn chat generation.

    The requested checkpoint is tried first; if loading it raises, the public
    ``Qwen/Qwen2.5-VL-7B-Instruct`` checkpoint is loaded instead using exactly
    the same dtype / device settings.

    Attributes set by ``__init__``:
        model_id: the checkpoint that was *requested*.
        loaded_model_id: the checkpoint that was actually loaded.
        processor: the ``AutoProcessor`` belonging to the loaded checkpoint.
        model: the loaded model, switched to eval mode.
    """

    FALLBACK_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

    def __init__(self, model_id: str, to_device: str = "cuda"):
        """Load ``model_id`` onto ``to_device`` ("cuda" or "cpu"), with fallback.

        Raises:
            Exception: whatever the fallback load raises if it fails as well.
        """
        print(f"Loading {model_id} on {to_device}...")
        self.model_id = model_id
        try:
            self._load(model_id, to_device)
            print("Model loaded successfully.")
        except Exception as e:
            # Deliberately broad: any load failure (missing repo, auth, config
            # mismatch) should fall through to the public fallback checkpoint.
            print(f"Failed to load Fara, falling back to Qwen2.5-VL-7B for demo compatibility. Error: {e}")
            self._load(self.FALLBACK_MODEL_ID, to_device)

    def _load(self, model_id: str, to_device: str) -> None:
        """Load processor + model for ``model_id`` and publish them on ``self``.

        Both objects are built into locals first so a failure half-way does
        not leave a processor from one checkpoint paired with a model from
        another.
        """
        on_cuda = to_device == "cuda"
        # NOTE: trust_remote_code=True executes Python shipped in the model
        # repo; only point this at repositories you trust.
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        loaded = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if on_cuda else torch.float32,
            device_map="auto" if on_cuda else None,
        )
        if to_device == "cpu":
            loaded.to("cpu")
        loaded.eval()
        self.processor = processor
        self.model = loaded
        self.loaded_model_id = model_id

    def generate(self, messages: list[dict], max_new_tokens: int = 512) -> str:
        """Run one generation pass and return only the newly generated text.

        Args:
            messages: chat messages (dicts with "role" and "content") in the
                form accepted by the processor's chat template and by
                ``process_vision_info``.
            max_new_tokens: upper bound on the number of generated tokens.

        Returns:
            The decoded completion for the single prompt, without the prompt
            tokens and without special tokens.
        """
        # Prepare inputs for Fara/QwenVL
        prompt_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[prompt_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
            )
        # generate() returns prompt + completion; drop the prompt prefix.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
# Initialize the global model once, at import time; agent_step() calls
# model.generate() on this instance for every step.
# NOTE(review): `spaces` is imported at the top of the file but nothing visible
# in this file uses it — confirm whether the deployment hardware needs it.
model = FaraModelWrapper(MODEL_ID, DEVICE)
| # ----------------------------------------------------------------------------- | |
| # SELENIUM SANDBOX | |
| # ----------------------------------------------------------------------------- | |
def get_system_chrome_path():
    """Return the first known Chrome/Chromium binary path that exists, else None."""
    candidates = (
        "/usr/bin/chromium",
        "/usr/bin/chromium-browser",
        "/usr/bin/google-chrome",
    )
    # Candidates are probed in order; the first hit wins.
    return next((binary for binary in candidates if os.path.exists(binary)), None)
class SeleniumSandbox:
    """Headless Chrome session driven through Selenium on behalf of the agent.

    Each instance owns a Chrome process and a private temporary user-data
    directory; call ``cleanup()`` to release both.
    """

    # Short key names a model may emit that are not attribute names on
    # selenium's ``Keys`` class, mapped to the attribute that is.
    _KEY_ALIASES = {
        "ESC": "ESCAPE",
        "CTRL": "CONTROL",
        "DEL": "DELETE",
        "PAGEUP": "PAGE_UP",
        "PAGEDOWN": "PAGE_DOWN",
    }

    def __init__(self, width=1024, height=768):
        """Start headless Chrome with a ``width`` x ``height`` window.

        Uses ``/usr/bin/chromedriver`` when present, otherwise lets
        webdriver-manager install a driver.

        Raises:
            Exception: re-raises whatever prevented the driver from starting,
                after removing the temporary profile directory.
        """
        self.width = width
        self.height = height
        self.tmp_dir = tempfile.mkdtemp(prefix="chrome_sandbox_")
        chrome_opts = ChromeOptions()
        binary_path = get_system_chrome_path()
        if binary_path:
            chrome_opts.binary_location = binary_path
        chrome_opts.add_argument("--headless=new")
        chrome_opts.add_argument(f"--user-data-dir={self.tmp_dir}")
        chrome_opts.add_argument(f"--window-size={width},{height}")
        chrome_opts.add_argument("--no-sandbox")
        chrome_opts.add_argument("--disable-dev-shm-usage")
        chrome_opts.add_argument("--disable-gpu")
        try:
            system_driver_path = "/usr/bin/chromedriver"
            if os.path.exists(system_driver_path):
                service = ChromeService(executable_path=system_driver_path)
            else:
                service = ChromeService(ChromeDriverManager().install())
            self.driver = webdriver.Chrome(service=service, options=chrome_opts)
            self.driver.set_window_size(width, height)
            print("Selenium started.")
        except Exception as e:
            print(f"Selenium init failed: {e}")
            shutil.rmtree(self.tmp_dir, ignore_errors=True)
            raise  # bare raise keeps the original traceback

    def get_screenshot(self):
        """Return the current browser screenshot as a PIL image."""
        return Image.open(BytesIO(self.driver.get_screenshot_as_png()))

    def _viewport_size(self):
        """Return (width, height) of the page viewport in CSS pixels.

        Falls back to the configured window size if the page does not report
        usable numbers.
        """
        reported = self.driver.execute_script(
            "return [window.innerWidth, window.innerHeight];"
        )
        try:
            view_w, view_h = int(reported[0]), int(reported[1])
        except (TypeError, ValueError, IndexError):
            return self.width, self.height
        if view_w <= 0 or view_h <= 0:
            return self.width, self.height
        return view_w, view_h

    def _to_viewport_pixels(self, x_norm, y_norm):
        """Convert normalized (0-1) coordinates to clamped viewport pixels."""
        view_w, view_h = self._viewport_size()
        # Clamp so a slightly out-of-range prediction still lands inside the
        # viewport instead of raising an out-of-bounds move error.
        x_px = min(max(int(float(x_norm) * view_w), 0), view_w - 1)
        y_px = min(max(int(float(y_norm) * view_h), 0), view_h - 1)
        return x_px, y_px

    def _resolve_key(self, key_name):
        """Map a key name such as "enter" or "esc" to a ``Keys`` value, or None."""
        normalized = str(key_name).strip().upper().replace(" ", "_").replace("-", "_")
        normalized = self._KEY_ALIASES.get(normalized, normalized)
        if not normalized or normalized.startswith("_"):
            return None
        key_value = getattr(Keys, normalized, None)
        return key_value if isinstance(key_value, str) else None

    def execute_action(self, action_data: dict):
        """Execute a parsed action on the browser.

        Args:
            action_data: dict with a ``'type'`` key plus the fields for that
                type: ``x``/``y`` (normalized 0-1) for click, right_click and
                double_click; ``text`` for type_text; ``key`` for press_key;
                ``amount``/``direction`` for scroll; ``url`` for open_url.

        Returns:
            ``"Executed <type>"`` on success, otherwise a string starting with
            ``"Action failed:"``. Exceptions are reported this way rather
            than raised.
        """
        action_type = action_data.get('type')
        try:
            actions = ActionChains(self.driver)
            if action_type in ('click', 'right_click', 'double_click'):
                x_px, y_px = self._to_viewport_pixels(action_data['x'], action_data['y'])
                # Move relative to the viewport origin. Anchoring on <body>
                # with move_to_element_with_offset is unreliable: since
                # Selenium 4.3 that offset is measured from the element's
                # centre, not its top-left corner.
                actions.w3c_actions.pointer_action.move_to_location(x_px, y_px)
                actions.w3c_actions.key_action.pause()
                if action_type == 'click':
                    actions.click()
                elif action_type == 'right_click':
                    actions.context_click()
                else:
                    actions.double_click()
                actions.perform()
            elif action_type == 'type_text':
                actions.send_keys(action_data.get('text', ''))
                actions.perform()
            elif action_type == 'press_key':
                requested_key = action_data.get('key', '')
                key_value = self._resolve_key(requested_key)
                if key_value is None:
                    # Previously an unknown key was silently reported as executed.
                    return f"Action failed: unknown key {requested_key!r}"
                actions.send_keys(key_value)
                actions.perform()
            elif action_type == 'scroll':
                amount = action_data.get('amount', 2)
                direction = action_data.get('direction', 'down')
                scroll_y = amount * 100  # one "amount" unit is 100 px
                if direction == 'up':
                    scroll_y = -scroll_y
                # Pass the value as a script argument instead of formatting it
                # into the JavaScript source.
                self.driver.execute_script("window.scrollBy(0, arguments[0]);", scroll_y)
            elif action_type == 'open_url':
                url = str(action_data.get('url', '')).strip()
                if not url:
                    return "Action failed: empty url"
                # Check for a real scheme; a bare 'http' prefix test would
                # mistake hosts such as "httpbin.org" for full URLs.
                if not url.lower().startswith(('http://', 'https://')):
                    url = 'https://' + url
                self.driver.get(url)
            else:
                # Previously an unrecognised type fell through and was
                # reported as executed although nothing happened.
                return f"Action failed: unsupported action type {action_type!r}"
            time.sleep(2)  # give the page time to react before the next screenshot
            return f"Executed {action_type}"
        except Exception as e:
            return f"Action failed: {e}"

    def cleanup(self):
        """Quit the browser (best effort) and delete the temporary profile."""
        try:
            self.driver.quit()
        except Exception as e:
            # Best effort: the driver may be gone already or never have started.
            print(f"Selenium quit failed: {e}")
        finally:
            shutil.rmtree(self.tmp_dir, ignore_errors=True)
| # ----------------------------------------------------------------------------- | |
| # PARSING LOGIC | |
| # ----------------------------------------------------------------------------- | |
def parse_code_block(response: str) -> str:
    """Return the stripped body of the last <code>...</code> block in ``response``.

    Returns an empty string when the response contains no such block.
    """
    last_body = ""
    # Walk every block; whichever is seen last is the one returned.
    for found in re.finditer(r"<code>\s*(.*?)\s*</code>", response, re.DOTALL):
        last_body = found.group(1).strip()
    return last_body
def parse_action_string(action_str: str) -> dict:
    """Parse an action call such as ``click(x=0.5, y=0.5)`` into a dict.

    Recognised forms and their results:
        ``name(x=<num>, y=<num>)``   -> ``{"type": name, "x": float, "y": float}``
        ``open_url(url="...")``      -> ``{"type": "open_url", "url": str}``
        ``type_text(text="...")``    -> ``{"type": "type_text", "text": str}``
        ``press_key(key="...")``     -> ``{"type": "press_key", "key": str}``
        anything containing "scroll" -> ``{"type": "scroll", "amount": int,
                                          "direction": str}``; amount defaults
                                          to 2 and direction to "down".

    Returns an empty dict when nothing matches or a coordinate is not a
    valid number.
    """
    # Tolerate None and surrounding whitespace; re.match anchors at the start.
    action_str = (action_str or "").strip()
    if not action_str:
        return {}

    # 1. Coordinate actions: name(x=..., y=...)
    coord_match = re.match(
        r"(\w+)\s*\(\s*x\s*=\s*([0-9.]+)\s*,\s*y\s*=\s*([0-9.]+)\s*\)", action_str
    )
    if coord_match:
        try:
            return {
                "type": coord_match.group(1),
                "x": float(coord_match.group(2)),
                "y": float(coord_match.group(3)),
            }
        except ValueError:
            # "[0-9.]+" also accepts strings like "1.2.3"; treat as unparseable.
            return {}

    # 2. Open URL: open_url(url="...")
    url_match = re.match(r"open_url\s*\(\s*url\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
    if url_match:
        return {"type": "open_url", "url": url_match.group(1)}

    # 3. Type text: type_text(text="...")
    text_match = re.match(r"type_text\s*\(\s*text\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
    if text_match:
        return {"type": "type_text", "text": text_match.group(1)}

    # 4. Press key: press_key(key="...")
    key_match = re.match(r"press_key\s*\(\s*key\s*=\s*[\"'](.*?)[\"']\s*\)", action_str)
    if key_match:
        return {"type": "press_key", "key": key_match.group(1)}

    # 5. Scroll: scroll(amount=..., direction="...")
    # The keyword arguments are read in either order; missing ones keep the
    # defaults, so a bare "scroll" still yields the default action.
    if "scroll" in action_str:
        scroll_action = {"type": "scroll", "amount": 2, "direction": "down"}
        amount_match = re.search(r"amount\s*=\s*(\d+)", action_str)
        if amount_match:
            scroll_action["amount"] = int(amount_match.group(1))
        direction_match = re.search(r"direction\s*=\s*[\"'](\w+)[\"']", action_str)
        if direction_match:
            scroll_action["direction"] = direction_match.group(1).lower()
        return scroll_action

    return {}
| # ----------------------------------------------------------------------------- | |
| # MAIN LOOP | |
| # ----------------------------------------------------------------------------- | |
def agent_step(task_instruction: str, history: list, sandbox_state: dict):
    """Run one observe -> think -> act cycle of the agent.

    Args:
        task_instruction: the user's task, inserted into the prompt.
        history: list of log strings from earlier steps; this step's entry is
            appended in place. Only the last entry is shown to the model.
        sandbox_state: per-session dict; a ``'uuid'`` key is added on first
            use and identifies this session's browser in SANDBOX_REGISTRY.

    Returns:
        ``(screenshot, history, sandbox_state)``. The screenshot is the one
        taken *before* the action ran, with a red circle drawn at the target
        of a coordinate action.
    """
    # Browser objects are kept in a module-level registry keyed by a uuid so
    # only the uuid has to live in the Gradio state. Exactly one sandbox is
    # created per uuid; the earlier code also launched a second, unreferenced
    # browser on the first call, which was never cleaned up.
    if 'uuid' not in sandbox_state:
        sandbox_state['uuid'] = str(uuid.uuid4())
    sandbox_id = sandbox_state['uuid']
    sandbox = SANDBOX_REGISTRY.get(sandbox_id)
    if sandbox is None:
        sandbox = SeleniumSandbox(WIDTH, HEIGHT)
        SANDBOX_REGISTRY[sandbox_id] = sandbox

    # 1. Get Screenshot
    screenshot = sandbox.get_screenshot()

    # 2. Construct Prompt
    previous_actions = history[-1] if history else 'None'
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": screenshot},
                {"type": "text", "text": f"Instruction: {task_instruction}\nPrevious Actions: {previous_actions}"},
            ],
        },
    ]

    # 3. Model Inference
    response = model.generate(messages)

    # 4. Parse Action
    action_code = parse_code_block(response)
    action_data = parse_action_string(action_code)
    log_entry = f"Step: {len(history)+1}\nModel Thought: {response}\nAction: {action_code}"

    # 5. Execute Action
    execution_result = "No valid action found"
    if action_data:
        execution_result = sandbox.execute_action(action_data)
        # Draw marker if coordinate action. Scale by the screenshot's real
        # size, which need not equal the configured WIDTH x HEIGHT.
        if 'x' in action_data:
            draw = ImageDraw.Draw(screenshot)
            x_px = action_data['x'] * screenshot.width
            y_px = action_data['y'] * screenshot.height
            r = 10
            draw.ellipse((x_px - r, y_px - r, x_px + r, y_px + r), outline="red", width=3)
    log_entry += f"\nResult: {execution_result}"
    history.append(log_entry)

    # Return updated screenshot and history
    return screenshot, history, sandbox_state
# Global registry for sandboxes: maps the session uuid stored in the Gradio
# state to that session's live SeleniumSandbox.
SANDBOX_REGISTRY = {}


def cleanup_sandbox(sandbox_state):
    """Shut down the session's sandbox (if any) and return fresh state values.

    Args:
        sandbox_state: per-session dict that may contain a ``'uuid'`` key;
            ``None`` is treated as an empty dict.

    Returns:
        ``([], {})`` — an empty history list and an empty state dict.
    """
    sid = (sandbox_state or {}).get('uuid')
    # Remove the entry first so it is gone from the registry even if
    # cleanup() raises.
    sandbox = SANDBOX_REGISTRY.pop(sid, None) if sid else None
    if sandbox is not None:
        sandbox.cleanup()
    return [], {}
| # ----------------------------------------------------------------------------- | |
| # GRADIO UI | |
| # ----------------------------------------------------------------------------- | |
def run_task_loop(task, history, state):
    """Generator that runs the agent for up to 10 steps, yielding UI updates.

    Args:
        task: the task instruction text.
        history: list of log strings; agent_step() appends to it.
        state: per-session dict passed through to agent_step().

    Yields:
        ``(screenshot, logs_text, state)`` after every step. If a step
        raises, a final ``(None, logs_text, state)`` with the error appended
        is yielded and the loop ends.
    """
    max_steps = 10
    # Separator placed *between* log entries. The earlier expression
    # `"\n\n" + "-"*40 + "\n\n".join(history)` bound the join tighter than the
    # concatenation, so the rule was printed once at the top instead.
    separator = "\n\n" + "-" * 40 + "\n\n"
    history = [] if history is None else history
    state = {} if state is None else state
    for _ in range(max_steps):
        try:
            # Run one step
            screenshot, history, state = agent_step(task, history, state)
            # Yield the joined logs and the latest image to the UI
            yield screenshot, separator.join(history), state
            # Check for termination (simplistic): looks at the whole last log
            # entry, including the model's raw response.
            last_entry = history[-1]
            if "Done" in last_entry or "finished" in last_entry.lower():
                break
            time.sleep(1)  # Pause for visual effect
        except Exception as e:
            history.append(f"Error in loop: {e}")
            yield None, "\n".join(history), state
            break
# UI Layout
# CSS applied to the image component whose elem_id is "view_img".
custom_css = """
#view_img { height: 600px; object-fit: contain; }
"""
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Session dict; agent_step() stores the sandbox uuid in it.
    state = gr.State({})
    # List of log strings; agent_step() appends one entry per step.
    history = gr.State([])
    gr.Markdown("# 🤖 Fara CUA - Chrome Agent")
    with gr.Row():
        # Left column: task input and controls.
        with gr.Column(scale=1):
            task_input = gr.Textbox(label="Task Instruction", value="Go to google.com and search for 'SpaceX'")
            run_btn = gr.Button("Run Agent", variant="primary")
            clear_btn = gr.Button("Reset / Clear")
        # Right column: screenshot and logs.
        with gr.Column(scale=2):
            browser_view = gr.Image(label="Live Browser View", elem_id="view_img", interactive=False)
            logs_output = gr.Textbox(label="Agent Logs", lines=15, interactive=False)
    # Event handlers
    # run_task_loop is a generator; each yielded tuple updates the outputs.
    # NOTE(review): `history` is an input but not an output here, so the stored
    # list only changes through in-place appends — confirm that is intended.
    run_btn.click(
        fn=run_task_loop,
        inputs=[task_input, history, state],
        outputs=[browser_view, logs_output, state]
    )
    # Reset: tear down the sandbox and reset both states, then blank the
    # image and the log box.
    clear_btn.click(
        fn=cleanup_sandbox,
        inputs=[state],
        outputs=[history, state]
    ).then(
        lambda: (None, ""),
        outputs=[browser_view, logs_output]
    )
if __name__ == "__main__":
    demo.launch()