# Spaces: Running on Zero
| import json | |
| import os | |
| import shutil | |
| import time | |
| import uuid | |
| import tempfile | |
| import atexit | |
| import unicodedata | |
| from io import BytesIO | |
| from threading import Timer | |
| from typing import Any, Dict, List, Optional | |
| from datetime import datetime | |
| import gradio as gr | |
| import torch | |
| import spaces | |
| from dotenv import load_dotenv | |
| from PIL import Image, ImageDraw | |
| # Selenium Imports | |
| from selenium import webdriver | |
| from selenium.webdriver.chrome.service import Service as ChromeService | |
| from selenium.webdriver.chrome.options import Options as ChromeOptions | |
| from selenium.webdriver.common.action_chains import ActionChains | |
| from selenium.webdriver.common.by import By | |
| from selenium.webdriver.common.keys import Keys | |
| from webdriver_manager.chrome import ChromeDriverManager | |
| # Smolagents imports | |
| from smolagents import CodeAgent, tool, AgentImage | |
| from smolagents.memory import ActionStep, TaskStep | |
| from smolagents.models import ChatMessage, Model, MessageRole | |
| from smolagents.gradio_ui import GradioUI, stream_to_gradio | |
| from smolagents.monitoring import LogLevel | |
| # Transformers for Fara Model | |
| from transformers import ( | |
| Qwen2_5_VLForConditionalGeneration, | |
| AutoProcessor, | |
| ) | |
| from qwen_vl_utils import process_vision_info | |
load_dotenv(override=True)
# -----------------------------------------------------------------------------
# CONFIGURATION & CONSTANTS
# -----------------------------------------------------------------------------
# Either environment variable name is accepted for the Hugging Face token;
# log in only when one of them is set.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
if HF_TOKEN:
    from huggingface_hub import login

    login(token=HF_TOKEN)

# Browser Sandbox Config
WIDTH = 1024
HEIGHT = 768
TMP_DIR = "./tmp/"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(TMP_DIR, exist_ok=True)
# -----------------------------------------------------------------------------
# MODEL INITIALIZATION (Fara-7B / Qwen2.5-VL)
# -----------------------------------------------------------------------------
print("Loading Fara Model... This may take a moment.")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID_F = "microsoft/Fara-7B"
FALLBACK_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"


def _load_vl_model(model_id):
    """Load and return ``(processor, model)`` for a Qwen2.5-VL style checkpoint.

    Uses bfloat16 on CUDA and float32 otherwise, with ``device_map="auto"``.
    Any loading error propagates to the caller; nothing global is modified.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
        device_map="auto",
    )
    return processor, model


# Global model variables (stay None if every load attempt fails).
model_f = None
processor_f = None
try:
    processor_f, model_f = _load_vl_model(MODEL_ID_F)
    print(f"Fara Model loaded successfully on {DEVICE}")
except Exception as e:
    print(f"Error loading Fara Model: {e}")
    print(f"Falling back to {FALLBACK_MODEL_ID}...")
    try:
        processor_f, model_f = _load_vl_model(FALLBACK_MODEL_ID)
        # Only report the fallback id once it has actually loaded.
        MODEL_ID_F = FALLBACK_MODEL_ID
        print(f"Fallback Model ({MODEL_ID_F}) loaded successfully.")
    except Exception as inner_e:
        print(f"Critical error loading model: {inner_e}")
# -----------------------------------------------------------------------------
# GPU ISOLATED INFERENCE FUNCTION
# -----------------------------------------------------------------------------
@spaces.GPU
def run_model_inference(formatted_messages, max_tokens=1024, stop_sequences=None):
    """Generate a reply from the loaded vision-language model.

    ``@spaces.GPU`` requests a GPU for each call on a ZeroGPU Space (this file
    imports ``spaces`` for that purpose); outside ZeroGPU the decorator is
    documented to have no effect.

    Args:
        formatted_messages: Qwen-VL style chat messages, i.e. dicts with a
            "role" and a "content" list of ``{"type": "text"|"image", ...}``.
        max_tokens: maximum number of new tokens to generate.
        stop_sequences: optional list of strings that end generation; an empty
            list is treated like None.

    Returns:
        The decoded text of the newly generated tokens only (prompt excluded).

    Raises:
        ValueError: if the model or processor failed to load at start-up.
    """
    if model_f is None or processor_f is None:
        raise ValueError("Model is not loaded.")

    text = processor_f.apply_chat_template(
        formatted_messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(formatted_messages)
    inputs = processor_f(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model_f.device)

    with torch.no_grad():
        generated_ids = model_f.generate(
            **inputs,
            max_new_tokens=max_tokens,
            # Normalise [] to None so no empty stop-string criterion is built.
            stop_strings=stop_sequences or None,
            tokenizer=processor_f.tokenizer,
        )

    # generate() returns prompt + completion; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor_f.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return output_text
class FaraLocalModel(Model):
    """smolagents ``Model`` adapter around the locally loaded Fara/Qwen2.5-VL weights.

    Converts smolagents chat messages into the Qwen-VL message format and runs
    them through ``run_model_inference``. Both the ``generate`` entry point and
    the ``__call__`` entry point are provided, so the adapter works whichever
    one the agent invokes.
    """

    # smolagents' tool-call / tool-response roles are not part of the Qwen chat
    # template; map them onto roles the template understands.
    _ROLE_CONVERSIONS = {"tool-call": "assistant", "tool-response": "user"}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def _field(msg, key):
        """Read ``key`` from a dict message or an attribute-style ChatMessage."""
        if isinstance(msg, dict):
            return msg[key]
        return getattr(msg, key)

    @classmethod
    def _format_messages(cls, messages):
        """Return Qwen-VL style messages: ``{"role": str, "content": [items]}``.

        String content becomes a single text item. In list content, strings
        become text items, ``{"type": "image"}`` items are normalised to carry
        the image under "image" (taken from "image", "url" or "path"), other
        typed dicts pass through unchanged, and dicts without "type" are dropped.
        """
        formatted_messages = []
        for msg in messages:
            role = cls._field(msg, "role")
            role = getattr(role, "value", role)  # enum role -> plain string
            role = cls._ROLE_CONVERSIONS.get(role, role)
            content = cls._field(msg, "content")

            new_content = []
            if isinstance(content, str):
                new_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        new_content.append({"type": "text", "text": item})
                    elif isinstance(item, dict) and "type" in item:
                        if item["type"] == "image":
                            val = item.get("image") or item.get("url") or item.get("path")
                            new_content.append({"type": "image", "image": val})
                        else:
                            new_content.append(item)
            formatted_messages.append({"role": role, "content": new_content})
        return formatted_messages

    @staticmethod
    def _strip_stop_sequences(text, stop_sequences):
        """Remove trailing stop sequences (generation keeps the matched string)."""
        for stop in stop_sequences or ():
            if stop and text.endswith(stop):
                text = text[: -len(stop)]
        return text

    def generate(
        self,
        messages: List[Dict[str, Any]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> ChatMessage:
        """Run the local model on ``messages`` and return the assistant reply.

        ``kwargs["max_tokens"]`` (default 1024) caps the generation length;
        other keyword arguments are accepted and ignored.
        """
        output_text = run_model_inference(
            formatted_messages=self._format_messages(messages),
            max_tokens=kwargs.get("max_tokens", 1024),
            stop_sequences=stop_sequences,
        )
        return ChatMessage(
            role=MessageRole.ASSISTANT,
            content=self._strip_stop_sequences(output_text, stop_sequences),
        )

    def __call__(
        self,
        messages: List[Dict[str, Any]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> ChatMessage:
        """Alias of :meth:`generate` for callers that invoke the model directly."""
        return self.generate(messages, stop_sequences=stop_sequences, **kwargs)
# -----------------------------------------------------------------------------
# SELENIUM CHROME SANDBOX
# -----------------------------------------------------------------------------
def get_system_chrome_path():
    """Return the first Chrome/Chromium executable found at a well-known path.

    Candidates are checked in order: /usr/bin/chromium,
    /usr/bin/chromium-browser, /usr/bin/google-chrome (common locations on
    Linux / HF Spaces). Returns None when none of them exists.
    """
    candidates = (
        "/usr/bin/chromium",
        "/usr/bin/chromium-browser",
        "/usr/bin/google-chrome",
    )
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
class SeleniumSandbox:
    """Headless Chrome session, driven through Selenium, that the agent controls.

    Each instance owns a temporary Chrome profile directory, which
    :meth:`cleanup` deletes together with the driver. Pointer coordinates given
    to the action methods are viewport pixels measured from the top-left
    corner, i.e. the same space as the screenshots from :meth:`get_screenshot`.
    """

    # Fallback spellings for keys whose name does not match a ``Keys`` attribute.
    _KEY_ALIASES = {
        "enter": Keys.ENTER,
        "space": Keys.SPACE,
        "backspace": Keys.BACK_SPACE,
        "esc": Keys.ESCAPE,
    }

    def __init__(self, width=1024, height=768):
        """Start headless Chrome sized ``width`` x ``height``.

        Raises:
            Exception: whatever Selenium / webdriver_manager raised while
                starting the driver (the temporary profile is removed first).
        """
        self.width = width
        self.height = height
        self.tmp_dir = tempfile.mkdtemp(prefix="chrome_sandbox_")

        chrome_opts = ChromeOptions()
        # Use system binary if available (fixes status 127 in HF Spaces)
        binary_path = get_system_chrome_path()
        if binary_path:
            print(f"Using system Chrome binary at: {binary_path}")
            chrome_opts.binary_location = binary_path
        chrome_opts.add_argument("--headless=new")
        chrome_opts.add_argument(f"--user-data-dir={self.tmp_dir}")
        chrome_opts.add_argument(f"--window-size={width},{height}")
        chrome_opts.add_argument("--no-sandbox")  # Crucial for containers
        chrome_opts.add_argument("--disable-dev-shm-usage")  # Crucial for containers
        chrome_opts.add_argument("--disable-gpu")
        chrome_opts.add_argument("--disable-extensions")

        try:
            # Prefer a system chromedriver; otherwise let webdriver_manager fetch one.
            system_driver_path = "/usr/bin/chromedriver"
            if os.path.exists(system_driver_path):
                print(f"Using system ChromeDriver at: {system_driver_path}")
                service = ChromeService(executable_path=system_driver_path)
            else:
                print("Using webdriver_manager to install ChromeDriver...")
                service = ChromeService(ChromeDriverManager().install())
            self.driver = webdriver.Chrome(service=service, options=chrome_opts)
            self.driver.set_window_size(width, height)
            self._force_viewport_size(width, height)
            self.driver.get("about:blank")
            print("Selenium Chrome Driver started successfully.")
        except Exception as e:
            print(f"Failed to initialize Selenium: {e}")
            self.cleanup()
            raise

    def _force_viewport_size(self, width, height):
        """Best effort: pin the page viewport to ``width`` x ``height`` at 1x scale.

        ``--window-size`` / ``set_window_size`` size the outer window, so the
        page viewport (and the screenshots) may come out smaller than the
        resolution the agent is told about. The CDP device-metrics override
        keeps screenshot pixels and click coordinates in the same space.
        Failures are printed and otherwise ignored.
        """
        try:
            self.driver.execute_cdp_cmd(
                "Emulation.setDeviceMetricsOverride",
                {"width": width, "height": height, "deviceScaleFactor": 1, "mobile": False},
            )
        except Exception as e:
            print(f"Could not override viewport size: {e}")

    def get_screenshot(self):
        """Return the current viewport as a PIL Image."""
        png_data = self.driver.get_screenshot_as_png()
        return Image.open(BytesIO(png_data))

    @staticmethod
    def _move_pointer(actions, x, y):
        """Queue an absolute pointer move to viewport point (x, y) on ``actions``."""
        actions.w3c_actions.pointer_action.move_to_location(int(x), int(y))
        # Keep the key input source one tick in step with the pointer source.
        actions.w3c_actions.key_action.pause()

    def _pointer_at(self, x, y):
        """Return an ActionChains whose pointer has been moved to viewport (x, y)."""
        actions = ActionChains(self.driver)
        self._move_pointer(actions, x, y)
        return actions

    def move_mouse_and_click(self, x, y, click_type="left"):
        """Move the pointer to viewport (x, y) and click there.

        ``click_type`` is "left", "right" or "double"; any other value only
        moves the pointer. Errors (e.g. a point outside the viewport) are
        printed, not raised.
        """
        try:
            # Absolute viewport move. The previous body-relative offset was
            # measured from the body's centre in Selenium 4, so clicks missed.
            actions = self._pointer_at(x, y)
            if click_type == "left":
                actions.click()
            elif click_type == "right":
                actions.context_click()
            elif click_type == "double":
                actions.double_click()
            actions.perform()
        except Exception as e:
            print(f"Error in move_mouse_and_click: {e}")

    def drag_and_drop(self, x1, y1, x2, y2):
        """Press at viewport (x1, y1), move to (x2, y2) and release.

        Errors are printed, not raised.
        """
        try:
            actions = self._pointer_at(x1, y1)
            actions.click_and_hold()
            self._move_pointer(actions, x2, y2)
            actions.release()
            actions.perform()
        except Exception as e:
            print(f"Error in drag_and_drop: {e}")

    def type_text(self, text):
        """Send ``text`` as keystrokes to the focused element (errors propagate)."""
        actions = ActionChains(self.driver)
        actions.send_keys(text)
        actions.perform()

    def press_key(self, key_name):
        """Press one key given by name (a ``Keys`` attribute name or an alias).

        Names are matched case-insensitively against ``Keys``, then against
        ``_KEY_ALIASES``; anything else is sent as literal text. Errors are
        printed, not raised.
        """
        try:
            k = getattr(Keys, key_name.upper(), None)
            if not k:
                k = self._KEY_ALIASES.get(key_name.lower(), key_name)
            actions = ActionChains(self.driver)
            actions.send_keys(k)
            actions.perform()
        except Exception as e:
            print(f"Error pressing key: {e}")

    def scroll(self, amount, direction="down"):
        """Scroll the window vertically by ``amount * 100`` pixels.

        ``direction="up"`` scrolls up; any other value scrolls down. Errors are
        printed, not raised.
        """
        try:
            scroll_y = amount * 100
            if direction == "up":
                scroll_y = -scroll_y
            # Pass the offset as a script argument instead of formatting it
            # into the JavaScript source.
            self.driver.execute_script("window.scrollBy(0, arguments[0]);", scroll_y)
        except Exception as e:
            print(f"Error scrolling: {e}")

    def cleanup(self):
        """Quit the browser (if it was started) and delete the temporary profile.

        Safe to call more than once.
        """
        driver = getattr(self, "driver", None)
        if driver is not None:
            try:
                driver.quit()
            except Exception as e:
                print(f"Error quitting Chrome driver: {e}")
            self.driver = None
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
# -----------------------------------------------------------------------------
# AGENT SETUP
# -----------------------------------------------------------------------------
# System prompt for the browser agent. <<current_date>> is substituted at import
# time; <<resolution_x>>/<<resolution_y>> are substituted by SeleniumVisionAgent;
# the {%- ... %} / {{ ... }} parts are Jinja placeholders (presumably rendered by
# smolagents with the agent's tools).
SYSTEM_PROMPT_TEMPLATE = """You are a browser automation assistant controlling a Google Chrome web browser. The current date is <<current_date>>.
<action_process>
You will be given a task to solve in several steps. At each step you will perform an action.
After each action, you'll receive an updated screenshot of the browser.
Then you will proceed as follows, with these sections: don't skip any!
Short term goal: ...
What I see: ...
Reflection: ...
Action:
```python
click(254, 308)
```<end_code>
Always format your action ('Action:' part) as Python code blocks as shown above.
</action_process>
<tools>
On top of performing computations in the Python code snippets that you create, you only have access to these tools to interact with the browser:
{%- for tool in tools.values() %}
- {{ tool.name }}: {{ tool.description }}
    Takes inputs: {{tool.inputs}}
    Returns an output of type: {{tool.output_type}}
{%- endfor %}
</tools>
<click_guidelines>
The browser has a resolution of <<resolution_x>>x<<resolution_y>> pixels.
NEVER USE HYPOTHETICAL OR ASSUMED COORDINATES, USE TRUE COORDINATES that you can see from the screenshot.
Use precise coordinates based on the current screenshot.
Whenever you click, MAKE SURE to click in the middle of the button, text, link or any other clickable element.
In the screenshot you will see a green crosshair displayed over the position of your last click.
</click_guidelines>
<general_guidelines>
Execute one action at a time.
Use `open_url` to navigate to websites.
Use `click` to navigate links and interface elements.
Use `type_text` to input into forms.
Use `scroll` to see more content.
If you get stuck, try using `open_url` to search on Google.
</general_guidelines>
""".replace("<<current_date>>", datetime.now().strftime("%A, %d-%B-%Y"))
def draw_marker_on_image(image_copy, click_coordinates):
    """Draw a green click marker (cross inside a circle) onto an image in place.

    Args:
        image_copy: PIL image that is drawn on and then returned.
        click_coordinates: (x, y) pixel position of the marker centre.

    Returns:
        The same image object, now carrying the marker.
    """
    cx, cy = click_coordinates
    arm, stroke = 10, 3  # half-length of each cross arm, line width
    radius = 2 * arm
    pen = ImageDraw.Draw(image_copy)
    pen.line((cx - arm, cy, cx + arm, cy), fill="green", width=stroke)
    pen.line((cx, cy - arm, cx, cy + arm), fill="green", width=stroke)
    pen.ellipse(
        (cx - radius, cy - radius, cx + radius, cy + radius),
        outline="green",
        width=stroke,
    )
    return image_copy
class SeleniumVisionAgent(CodeAgent):
    """CodeAgent that operates a SeleniumSandbox browser from screenshots.

    Browser actions are exposed to the model as smolagents tools. After every
    step a screenshot is saved under ``data_dir`` and attached to that step's
    observations, with a marker drawn over the last click.
    """

    def __init__(
        self,
        model: Model,
        data_dir: str,
        sandbox: SeleniumSandbox,
        max_steps: int = 20,
        verbosity_level: LogLevel = 2,
        **kwargs,
    ):
        self.sandbox = sandbox
        self.data_dir = data_dir
        # Position of the last click; drawn on the next screenshot, then reset.
        self.click_coordinates = None

        print(f"Browser size: {self.sandbox.width}x{self.sandbox.height}")
        os.makedirs(self.data_dir, exist_ok=True)

        super().__init__(
            tools=[],
            model=model,
            max_steps=max_steps,
            verbosity_level=verbosity_level,
            **kwargs,
        )
        self.prompt_templates["system_prompt"] = SYSTEM_PROMPT_TEMPLATE.replace(
            "<<resolution_x>>", str(self.sandbox.width)
        ).replace("<<resolution_y>>", str(self.sandbox.height))
        self.register_tools()
        self.step_callbacks.append(self.take_screenshot_callback)

    def register_tools(self):
        """Create the browser tools and add them to ``self.tools`` by name.

        Each function is wrapped with smolagents' ``@tool`` so it carries the
        ``name`` / ``description`` / ``inputs`` / ``output_type`` attributes that
        the system-prompt template renders; plain functions have none of these.
        """

        @tool
        def click(x: int, y: int) -> str:
            """
            Performs a left-click at the specified coordinates.
            Args:
                x: The x coordinate (horizontal position).
                y: The y coordinate (vertical position).
            """
            self.sandbox.move_mouse_and_click(x, y, "left")
            self.click_coordinates = [x, y]
            return f"Clicked at ({x}, {y})"

        @tool
        def right_click(x: int, y: int) -> str:
            """
            Performs a right-click at the specified coordinates.
            Args:
                x: The x coordinate.
                y: The y coordinate.
            """
            self.sandbox.move_mouse_and_click(x, y, "right")
            self.click_coordinates = [x, y]
            return f"Right-clicked at ({x}, {y})"

        @tool
        def double_click(x: int, y: int) -> str:
            """
            Performs a double-click at the specified coordinates.
            Args:
                x: The x coordinate.
                y: The y coordinate.
            """
            self.sandbox.move_mouse_and_click(x, y, "double")
            self.click_coordinates = [x, y]
            return f"Double-clicked at ({x}, {y})"

        @tool
        def type_text(text: str) -> str:
            """
            Types the specified text.
            Args:
                text: The text to type.
            """
            clean_text = unicodedata.normalize("NFD", text)
            self.sandbox.type_text(clean_text)
            return f"Typed text: '{clean_text}'"

        @tool
        def press_key(key: str) -> str:
            """
            Presses a keyboard key (e.g., 'enter', 'backspace', 'esc').
            Args:
                key: The key name.
            """
            self.sandbox.press_key(key)
            return f"Pressed key: {key}"

        @tool
        def drag_and_drop(x1: int, y1: int, x2: int, y2: int) -> str:
            """
            Drags from (x1, y1) and drops at (x2, y2).
            Args:
                x1: Start x coordinate.
                y1: Start y coordinate.
                x2: End x coordinate.
                y2: End y coordinate.
            """
            self.sandbox.drag_and_drop(x1, y1, x2, y2)
            return f"Dragged from [{x1}, {y1}] to [{x2}, {y2}]"

        @tool
        def scroll(amount: int, direction: str = "down") -> str:
            """
            Scrolls the page.
            Args:
                amount: The amount to scroll (1-10).
                direction: "up" or "down".
            """
            self.sandbox.scroll(amount, direction)
            return f"Scrolled {direction} by {amount}"

        @tool
        def wait(seconds: float) -> str:
            """
            Waits for the specified number of seconds.
            Args:
                seconds: The duration to wait.
            """
            time.sleep(seconds)
            return f"Waited for {seconds} seconds"

        @tool
        def open_url(url: str) -> str:
            """
            Navigates the browser to the specified URL.
            Args:
                url: The URL to open.
            """
            if not url.startswith(("http://", "https://")):
                url = "https://" + url
            try:
                self.sandbox.driver.get(url)
                time.sleep(2)  # give the page a moment to render before reading it
                title = self.sandbox.driver.title
                return f"Opened URL: {url}. Page Title: {title}"
            except Exception as e:
                return f"Failed to open URL: {e}"

        @tool
        def go_back() -> str:
            """
            Goes back to the previous page in history.
            """
            self.sandbox.driver.back()
            return "Went back one page"

        for browser_tool in (
            click,
            right_click,
            double_click,
            type_text,
            press_key,
            drag_and_drop,
            scroll,
            wait,
            open_url,
            go_back,
        ):
            self.tools[browser_tool.name] = browser_tool

    def take_screenshot_callback(self, memory_step: ActionStep, agent=None) -> None:
        """Step callback: capture, save and attach a screenshot of the browser.

        The raw screenshot is written to ``data_dir/step_NNN.png``; a copy with
        the last click marked becomes the step's only observation image. Images
        of earlier steps and of the task step are dropped to save memory.

        Args:
            memory_step: the ActionStep that just finished.
            agent: the running agent; defaults to this agent when not provided.
        """
        if agent is None:
            agent = self
        current_step = memory_step.step_number
        time.sleep(1.0)  # Wait for renders
        image = self.sandbox.get_screenshot()

        screenshot_path = os.path.join(self.data_dir, f"step_{current_step:03d}.png")
        image.save(screenshot_path)

        image_copy = image.copy()
        if getattr(self, "click_coordinates", None):
            image_copy = draw_marker_on_image(image_copy, self.click_coordinates)
        # NOTE: this is the unmarked file saved above; only the in-memory copy
        # sent to the model carries the click marker.
        self.last_marked_screenshot = AgentImage(screenshot_path)

        # Cleanup old images in memory to save RAM
        for previous_memory_step in agent.memory.steps:
            if isinstance(previous_memory_step, ActionStep) and previous_memory_step.step_number < current_step:
                previous_memory_step.observations_images = None
            elif isinstance(previous_memory_step, TaskStep):
                previous_memory_step.task_images = None

        memory_step.observations_images = [image_copy]
        self.click_coordinates = None
def create_agent(data_dir, sandbox):
    """Build a SeleniumVisionAgent backed by the local Fara model.

    Args:
        data_dir: folder where the agent saves its step screenshots.
        sandbox: the SeleniumSandbox the agent's tools operate on.

    Returns:
        A new SeleniumVisionAgent limited to 30 steps, at verbosity level 2.
    """
    agent_settings = {"max_steps": 30, "verbosity_level": 2}
    return SeleniumVisionAgent(
        model=FaraLocalModel(),
        data_dir=data_dir,
        sandbox=sandbox,
        **agent_settings,
    )
def generate_interaction_id(session_uuid):
    """Return a run id of the form ``"<session_uuid>_<unix time in whole seconds>"``."""
    timestamp = int(time.time())
    return "{}_{}".format(session_uuid, timestamp)
def get_agent_summary_erase_images(agent):
    """Drop all images from the agent's memory and return it as messages.

    For every memory step, ``observations_images`` and ``task_images`` are set
    to None when the step has those attributes (the memory is modified in
    place), so the result is light enough to serialise. Returns whatever
    ``agent.write_memory_to_messages()`` produces.
    """
    image_fields = ("observations_images", "task_images")
    for step in agent.memory.steps:
        for field_name in image_fields:
            if hasattr(step, field_name):
                setattr(step, field_name, None)
    return agent.write_memory_to_messages()
def save_final_status(folder, status: str, summary, error_message=None) -> None:
    """Write run metadata to ``<folder>/metadata.json`` on a best-effort basis.

    The file holds a JSON object with "status", "summary" and "error_message";
    values that are not JSON-serialisable are written via ``str()``. Any
    failure (e.g. a missing folder) is printed instead of raised.
    """
    try:
        metadata_path = os.path.join(folder, "metadata.json")
        with open(metadata_path, "w") as output_file:
            payload = {"status": status, "summary": summary, "error_message": error_message}
            output_file.write(json.dumps(payload, default=str))
    except Exception as e:
        print(f"Failed to save metadata: {e}")
# -----------------------------------------------------------------------------
# UI & APP
# -----------------------------------------------------------------------------
# Extra CSS handed to gr.Blocks. The string starts and ends with a newline and
# holds one rule per line.
custom_css = "\n".join(
    [
        "",
        ".modal-container { margin: var(--size-16) auto!important; }",
        ".browser-container { position: relative; width: 100%; height: 600px; border: 1px solid #444; background: #222; display: flex; align-items: center; justify-content: center; overflow: hidden; }",
        ".browser-image { max-width: 100%; max-height: 100%; object-fit: contain; }",
        "#chatbot { height: 800px!important; }",
        "",
    ]
)
class EnrichedGradioUI(GradioUI):
    """GradioUI whose chat handler runs every task in a fresh Selenium sandbox."""

    def interact_with_agent(
        self,
        task_input,
        stored_messages,
        session_state,
        session_uuid,
        consent_storage,
        request: gr.Request,
    ):
        """Run ``task_input`` with a new browser agent, streaming progress to the UI.

        Yields ``(chat_messages, browser_view)`` pairs for the chatbot and the
        "Latest Browser View" image. ``browser_view`` is a screenshot path when
        a step separator arrives, and ``gr.update()`` (leave the image as it is)
        otherwise. The sandbox is always shut down when the run ends.

        Args:
            task_input: the user's task text.
            stored_messages: chat history list (appended to in place).
            session_state: per-session dict; the agent is stored under "agent"
                so the Stop button can reach it.
            session_uuid: per-session id used to name the run folder in TMP_DIR.
            consent_storage: when truthy, write ``metadata.json`` for the run.
            request: Gradio request object (unused).
        """
        interaction_id = generate_interaction_id(session_uuid)
        data_dir = os.path.join(TMP_DIR, interaction_id)
        sandbox = None
        agent = None
        try:
            stored_messages.append(gr.ChatMessage(role="user", content=task_input))
            yield stored_messages, gr.update()

            # Created inside the try so start-up failures are reported in the
            # chat and a started browser is always closed in ``finally``.
            sandbox = SeleniumSandbox(width=WIDTH, height=HEIGHT)
            agent = create_agent(data_dir=data_dir, sandbox=sandbox)
            session_state["agent"] = agent

            screenshot = sandbox.get_screenshot()
            for msg in stream_to_gradio(
                agent,
                task=task_input,
                task_images=[screenshot],
                reset_agent_memory=False,
            ):
                # A "-----" separator marks a finished step: show its screenshot.
                if hasattr(agent, "last_marked_screenshot") and msg.content == "-----":
                    screenshot_path = agent.last_marked_screenshot.to_string()
                    stored_messages.append(
                        gr.ChatMessage(
                            role="assistant",
                            content={"path": screenshot_path, "mime_type": "image/png"},
                        )
                    )
                    yield stored_messages, screenshot_path
                else:
                    stored_messages.append(msg)
                    # gr.update() keeps the current browser image; None would clear it.
                    yield stored_messages, gr.update()

            if consent_storage:
                summary = get_agent_summary_erase_images(agent)
                save_final_status(data_dir, "completed", summary=summary)
            yield stored_messages, gr.update()
        except Exception as e:
            error_message = f"Error in interaction: {str(e)}"
            print(error_message)
            stored_messages.append(
                gr.ChatMessage(role="assistant", content="Run failed:\n" + error_message)
            )
            if consent_storage and agent is not None:
                try:
                    summary = get_agent_summary_erase_images(agent)
                except Exception:
                    summary = None  # the failure may have left memory unusable
                save_final_status(data_dir, "failed", summary=summary, error_message=error_message)
            yield stored_messages, gr.update()
        finally:
            if sandbox is not None:
                sandbox.cleanup()
theme = gr.themes.Default(
    font=["Oxanium", "sans-serif"], primary_hue="amber", secondary_hue="blue"
)

with gr.Blocks(theme=theme, css=custom_css) as demo:
    # Per-session state: an id used to name run folders, a dict holding the
    # running agent (read by the Stop handler) and the chat history.
    session_uuid_state = gr.State(lambda: str(uuid.uuid4()))
    session_state = gr.State({})
    stored_messages = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Fara CUA - Chrome Agent 🌐")
            task_input = gr.Textbox(
                value="Go to google.com and search for 'Hugging Face'",
                label="Task",
                lines=3,
            )
            run_btn = gr.Button("Start Task", variant="primary")
            stop_btn = gr.Button("Stop", variant="secondary")
            consent_storage = gr.Checkbox(label="Save logs locally?", value=True)
            gr.Examples(
                examples=[
                    "Go to google.com and search for 'Hugging Face', then click the first link.",
                    "Go to wikipedia.org, type 'Python' in search, and click the search button.",
                ],
                inputs=task_input,
            )
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column(scale=1):
                    chatbot_display = gr.Chatbot(
                        label="Agent Trace",
                        type="messages",
                        height=800,
                        avatar_images=(None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png"),
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### Latest Browser View")
                    live_browser_view = gr.Image(
                        label="Browser View",
                        type="filepath",
                        interactive=False,
                        height=600,
                    )

    # GradioUI needs an agent at construction time. interact_with_agent builds
    # a fresh agent per task, so this placeholder is never run.
    agent_ui = EnrichedGradioUI(CodeAgent(tools=[], model=Model(), name="init"))

    def interrupt_agent(state):
        """Set the session agent's ``interrupt_switch`` flag, if it has one.

        Returns nothing because the Stop event declares no outputs.
        """
        agent = state.get("agent")
        if agent is not None and hasattr(agent, "interrupt_switch"):
            agent.interrupt_switch = True

    run_event = run_btn.click(
        fn=agent_ui.interact_with_agent,
        inputs=[
            task_input,
            stored_messages,
            session_state,
            session_uuid_state,
            consent_storage,
        ],
        outputs=[chatbot_display, live_browser_view],
    )
    stop_btn.click(fn=interrupt_agent, inputs=[session_state], outputs=[])

if __name__ == "__main__":
    demo.launch(share=True)