Spaces:
Sleeping
Sleeping
| """MCP (Model Context Protocol) client integration for browser-use. | |
| This module provides integration between external MCP servers and browser-use's action registry. | |
| MCP tools are dynamically discovered and registered as browser-use actions. | |
| Example usage: | |
| from browser_use import Tools | |
| from browser_use.mcp.client import MCPClient | |
| tools = Tools() | |
| # Connect to an MCP server | |
| mcp_client = MCPClient( | |
| server_name="my-server", | |
| command="npx", | |
| args=["@mycompany/mcp-server@latest"] | |
| ) | |
| # Register all MCP tools as browser-use actions | |
| await mcp_client.register_to_tools(tools) | |
| # Now use with Agent as normal - MCP tools are available as actions | |
| """ | |
| import asyncio | |
| import logging | |
| import time | |
| from typing import Any | |
| from pydantic import BaseModel, ConfigDict, Field, create_model | |
| from browser_use.agent.views import ActionResult | |
| from browser_use.telemetry import MCPClientTelemetryEvent, ProductTelemetry | |
| from browser_use.tools.registry.service import Registry | |
| from browser_use.tools.service import Tools | |
| from browser_use.utils import get_browser_use_version | |
| logger = logging.getLogger(__name__) | |
| # Import MCP SDK | |
| from mcp import ClientSession, StdioServerParameters, types | |
| from mcp.client.stdio import stdio_client | |
| MCP_AVAILABLE = True | |
class MCPClient:
    """Client for connecting to MCP servers and exposing their tools as browser-use actions.

    The server is started through the MCP stdio transport. The transport and the
    ``ClientSession`` are opened inside a background task (``_run_stdio_client``)
    so that both async context managers are entered and exited by the same task;
    ``connect()`` starts that task and ``disconnect()`` signals it to finish.
    """

    def __init__(
        self,
        server_name: str,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
    ):
        """Initialize MCP client.

        Args:
            server_name: Name of the MCP server (for logging and identification)
            command: Command to start the MCP server (e.g., "npx", "python")
            args: Arguments for the command (e.g., ["@playwright/mcp@latest"])
            env: Environment variables for the server process
        """
        self.server_name = server_name
        self.command = command
        self.args = args or []
        self.env = env

        self.session: ClientSession | None = None
        self._stdio_task = None
        self._read_stream = None
        self._write_stream = None
        self._tools: dict[str, types.Tool] = {}
        self._registered_actions: set[str] = set()
        self._connected = False
        # Set by disconnect() to let the background task leave its context managers
        self._disconnect_event = asyncio.Event()
        self._telemetry = ProductTelemetry()

    async def connect(self) -> None:
        """Connect to the MCP server and discover available tools.

        Does nothing when already connected. A telemetry event is captured for
        every attempt, successful or not.

        Raises:
            RuntimeError: If the background connection task ends before the session
                is ready (the underlying error is chained as ``__cause__``), or if
                the session is not ready within 10 seconds.
        """
        if self._connected:
            logger.debug(f'Already connected to {self.server_name}')
            return

        start_time = time.time()
        error_msg = None

        try:
            logger.info(f"🔌 Connecting to MCP server '{self.server_name}': {self.command} {' '.join(self.args)}")

            # Create server parameters
            server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)

            # A previous disconnect() leaves the event set; without clearing it the
            # new background task would tear the session down right after setup.
            self._disconnect_event.clear()

            # Start stdio client in background task
            self._stdio_task = asyncio.create_task(self._run_stdio_client(server_params))

            # Wait for connection to be established, but stop early if the
            # background task has already finished (e.g. the command failed to start)
            retries = 0
            max_retries = 100  # 10 second timeout (increased for parallel test execution)
            while not self._connected and not self._stdio_task.done() and retries < max_retries:
                await asyncio.sleep(0.1)
                retries += 1

            if not self._connected:
                # Stop the task so it is not left running, and retrieve its
                # exception so the real cause is reported instead of being dropped
                startup_error = await self._abort_stdio_task()
                if startup_error is not None:
                    error_msg = f"Failed to connect to MCP server '{self.server_name}': {startup_error}"
                    raise RuntimeError(error_msg) from startup_error
                if retries >= max_retries:
                    error_msg = f"Failed to connect to MCP server '{self.server_name}' after {max_retries * 0.1} seconds"
                else:
                    error_msg = f"Failed to connect to MCP server '{self.server_name}': connection closed during startup"
                raise RuntimeError(error_msg)

            logger.info(f"📦 Discovered {len(self._tools)} tools from '{self.server_name}': {list(self._tools.keys())}")

        except Exception as e:
            error_msg = str(e)
            raise
        finally:
            # Capture telemetry for connect action
            duration = time.time() - start_time
            self._telemetry.capture(
                MCPClientTelemetryEvent(
                    server_name=self.server_name,
                    command=self.command,
                    tools_discovered=len(self._tools),
                    version=get_browser_use_version(),
                    action='connect',
                    duration_seconds=duration,
                    error_message=error_msg,
                )
            )

    async def _abort_stdio_task(self) -> Exception | None:
        """Cancel the background stdio task (if still running) and return the error it ended with, if any."""
        task = self._stdio_task
        self._stdio_task = None
        if task is None:
            return None
        if not task.done():
            task.cancel()
        # gather(return_exceptions=True) hands back the task's outcome as a value,
        # so a cancellation of the task itself is not mistaken for our own.
        (outcome,) = await asyncio.gather(task, return_exceptions=True)
        return outcome if isinstance(outcome, Exception) else None

    async def _run_stdio_client(self, server_params: StdioServerParameters):
        """Run the stdio client connection in a background task.

        Opens the transport and session, initializes the session, records the
        server's tools in ``self._tools``, marks the client connected and then
        waits until ``disconnect()`` sets the disconnect event.
        """
        try:
            async with stdio_client(server_params) as (read_stream, write_stream):
                self._read_stream = read_stream
                self._write_stream = write_stream

                # Create and initialize session
                async with ClientSession(read_stream, write_stream) as session:
                    self.session = session

                    # Initialize the connection
                    await session.initialize()

                    # Discover available tools
                    tools_response = await session.list_tools()
                    self._tools = {tool.name: tool for tool in tools_response.tools}

                    # Mark as connected
                    self._connected = True

                    # Keep the connection alive until disconnect is called
                    await self._disconnect_event.wait()

        except Exception as e:
            logger.error(f'MCP server connection error: {e}')
            raise
        finally:
            # Runs on normal exit, error and cancellation alike
            self._connected = False
            self.session = None
            self._read_stream = None
            self._write_stream = None

    async def disconnect(self) -> None:
        """Disconnect from the MCP server.

        Does nothing when not connected. Errors during shutdown are logged and
        reported through telemetry rather than raised.
        """
        if not self._connected:
            return

        start_time = time.time()
        error_msg = None

        try:
            logger.info(f"🔌 Disconnecting from MCP server '{self.server_name}'")

            # Signal disconnect
            self._connected = False
            self._disconnect_event.set()

            # Wait for stdio task to finish
            task = self._stdio_task
            if task:
                try:
                    await asyncio.wait_for(task, timeout=2.0)
                # asyncio.TimeoutError is only an alias of the builtin TimeoutError
                # from Python 3.11 on; this spelling works on older versions too
                except asyncio.TimeoutError:
                    logger.warning(f"Timeout waiting for MCP server '{self.server_name}' to disconnect")
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
                self._stdio_task = None

            self._tools.clear()
            self._registered_actions.clear()

        except Exception as e:
            error_msg = str(e)
            logger.error(f'Error disconnecting from MCP server: {e}')
        finally:
            # Capture telemetry for disconnect action
            duration = time.time() - start_time
            self._telemetry.capture(
                MCPClientTelemetryEvent(
                    server_name=self.server_name,
                    command=self.command,
                    tools_discovered=0,  # Tools cleared on disconnect
                    version=get_browser_use_version(),
                    action='disconnect',
                    duration_seconds=duration,
                    error_message=error_msg,
                )
            )
            self._telemetry.flush()

    async def register_to_tools(
        self,
        tools: Tools,
        tool_filter: list[str] | None = None,
        prefix: str | None = None,
    ) -> None:
        """Register MCP tools as actions in the browser-use tools.

        Connects first if the client is not connected yet. Action names already
        registered by this client are skipped.

        Args:
            tools: Browser-use tools to register actions to
            tool_filter: Optional list of tool names to register (None or empty = all tools)
            prefix: Optional prefix to add to action names (e.g., "playwright_")
        """
        if not self._connected:
            await self.connect()

        registry = tools.registry

        for tool_name, tool in self._tools.items():
            # Skip if not in filter
            if tool_filter and tool_name not in tool_filter:
                continue

            # Apply prefix if specified
            action_name = f'{prefix}{tool_name}' if prefix else tool_name

            # Skip if already registered
            if action_name in self._registered_actions:
                continue

            # Register the tool as an action
            self._register_tool_as_action(registry, action_name, tool)
            self._registered_actions.add(action_name)

        logger.info(f"✅ Registered {len(self._registered_actions)} MCP tools from '{self.server_name}' as browser-use actions")

    def _register_tool_as_action(self, registry: Registry, action_name: str, tool: Any) -> None:
        """Register a single MCP tool as a browser-use action.

        Args:
            registry: Browser-use registry to register action to
            action_name: Name for the registered action
            tool: MCP Tool object with schema information
        """
        # MCP tools describe their parameters with JSON Schema; turn the top-level
        # properties into a Pydantic model (None when the tool takes no parameters)
        param_model = None
        if tool.inputSchema:
            properties = tool.inputSchema.get('properties', {})
            required = set(tool.inputSchema.get('required', []))
            param_fields = self._build_model_fields(properties, required, action_name)
            if param_fields:
                param_model = self._create_params_model(f'{action_name}_Params', param_fields)

        # Note: page_filter has been removed since we no longer use Page objects
        # Browser tools filtering would need to be done via domain filters instead
        domains = None

        # Create async wrapper function for the MCP tool
        # Need to define function with explicit parameters to satisfy registry validation
        if param_model:
            # Type 1: Function takes param model as first parameter
            async def mcp_action_wrapper(params: param_model) -> ActionResult:  # type: ignore[no-redef]
                """Wrapper function that calls the MCP tool."""
                # Convert pydantic model to dict for MCP call
                return await self._call_tool(tool.name, params.model_dump(exclude_none=True))

        else:
            # No parameters - empty function signature
            async def mcp_action_wrapper() -> ActionResult:  # type: ignore[no-redef]
                """Wrapper function that calls the MCP tool."""
                return await self._call_tool(tool.name, None)

        # Set function metadata for better debugging
        mcp_action_wrapper.__name__ = action_name
        mcp_action_wrapper.__qualname__ = f'mcp.{self.server_name}.{action_name}'

        # Register the action with browser-use
        description = tool.description or f'MCP tool from {self.server_name}: {tool.name}'

        # Use the registry's action decorator
        registry.action(description=description, param_model=param_model, domains=domains)(mcp_action_wrapper)

        logger.debug(f"✅ Registered MCP tool '{tool.name}' as action '{action_name}'")

    async def _call_tool(self, tool_name: str, tool_params: dict[str, Any] | None) -> ActionResult:
        """Call one MCP tool, convert its result to an ActionResult and capture telemetry.

        Args:
            tool_name: Name of the tool on the MCP server
            tool_params: Arguments for the tool, or None for a tool without parameters

        Returns:
            ActionResult with the formatted tool output, or with ``error`` set when
            the client is not connected, the call raises, or the server flags the
            result as an error.
        """
        if not self.session or not self._connected:
            return ActionResult(error=f"MCP server '{self.server_name}' not connected", success=False)

        if tool_params is None:
            logger.debug(f"🔧 Calling MCP tool '{tool_name}' with no params")
        else:
            logger.debug(f"🔧 Calling MCP tool '{tool_name}' with params: {tool_params}")

        start_time = time.time()
        error_msg = None

        try:
            # Call the MCP tool
            result = await self.session.call_tool(tool_name, tool_params or {})

            # Convert MCP result to ActionResult
            extracted_content = self._format_mcp_result(result)

            # A tool that fails while executing reports it inside the result
            # (isError) instead of raising, so it has to be checked explicitly.
            # "is True" keeps objects with a non-boolean attribute from counting.
            if getattr(result, 'isError', False) is True:
                error_msg = f"MCP tool '{tool_name}' failed: {extracted_content}"
                logger.error(error_msg)
                return ActionResult(error=error_msg, success=False)

            return ActionResult(
                extracted_content=extracted_content,
                long_term_memory=f"Used MCP tool '{tool_name}' from {self.server_name}",
            )

        except Exception as e:
            error_msg = f"MCP tool '{tool_name}' failed: {str(e)}"
            logger.error(error_msg)
            return ActionResult(error=error_msg, success=False)
        finally:
            # Capture telemetry for tool call
            duration = time.time() - start_time
            self._telemetry.capture(
                MCPClientTelemetryEvent(
                    server_name=self.server_name,
                    command=self.command,
                    tools_discovered=len(self._tools),
                    version=get_browser_use_version(),
                    action='tool_call',
                    tool_name=tool_name,
                    duration_seconds=duration,
                    error_message=error_msg,
                )
            )

    def _format_mcp_result(self, result: Any) -> str:
        """Format MCP tool result into a string for ActionResult.

        Args:
            result: Raw result from MCP tool call

        Returns:
            Formatted string representation of the result
        """
        if hasattr(result, 'content'):
            # Structured content response
            content = result.content
            if isinstance(content, list):
                return '\n'.join(self._content_item_to_text(item) for item in content)
            return str(content)

        if isinstance(result, list):
            # List of content items
            return '\n'.join(self._content_item_to_text(item) for item in result)

        # Direct result or unknown format
        return str(result)

    @staticmethod
    def _content_item_to_text(item: Any) -> str:
        """Return an item's ``text`` attribute when it has one, else ``str(item)``."""
        if hasattr(item, 'text'):
            # str() keeps the join in the caller from failing on a non-string text value
            return str(item.text)
        return str(item)

    def _build_model_fields(self, properties: dict, required: set, name_prefix: str) -> dict[str, Any]:
        """Convert JSON Schema properties into ``create_model`` field definitions.

        Args:
            properties: Mapping of property name to its JSON Schema
            required: Names of the properties that are required
            name_prefix: Prefix used to name nested models (``{prefix}_{property}``)

        Returns:
            Mapping of property name to a ``(type, Field(...))`` tuple
        """
        fields: dict[str, Any] = {}
        for prop_name, prop_schema in properties.items():
            prop_type = self._json_schema_to_python_type(prop_schema, f'{name_prefix}_{prop_name}')

            if prop_name in required:
                default = ...  # Required field
            else:
                # Optional field - make type optional; default to None unless the schema gives one
                prop_type = prop_type | None
                default = prop_schema.get('default')

            # Add field with description if available
            field_kwargs = {}
            if 'description' in prop_schema:
                field_kwargs['description'] = prop_schema['description']

            fields[prop_name] = (prop_type, Field(default, **field_kwargs))
        return fields

    @staticmethod
    def _create_params_model(model_name: str, fields: dict[str, Any]) -> type[BaseModel]:
        """Create a Pydantic model that forbids extra fields from the given field definitions."""

        class ConfiguredBaseModel(BaseModel):
            model_config = ConfigDict(extra='forbid', validate_by_name=True, validate_by_alias=True)

        return create_model(model_name, __base__=ConfiguredBaseModel, **fields)

    @staticmethod
    def _union_types(variants: list[Any]) -> Any:
        """Combine a non-empty list of types into one union using the ``|`` operator."""
        combined = variants[0]
        for variant in variants[1:]:
            combined = combined | variant
        return combined

    def _json_schema_to_python_type(self, schema: dict, model_name: str = 'NestedModel') -> Any:
        """Convert JSON Schema type to Python type.

        Handles the basic scalar types, objects (as nested Pydantic models), typed
        arrays, ``type`` given as a list of alternatives, and ``anyOf`` / ``oneOf``
        on schemas that have no ``type``. Unknown or missing types map to ``str``.

        Args:
            schema: JSON Schema definition
            model_name: Name for nested models

        Returns:
            Python type corresponding to the schema
        """
        json_type = schema.get('type')

        # Schemas without "type" may describe alternatives instead, e.g.
        # {"anyOf": [{"type": "integer"}, {"type": "null"}]}
        if json_type is None and 'enum' not in schema:
            for keyword in ('anyOf', 'oneOf'):
                candidates = schema.get(keyword)
                if not isinstance(candidates, list):
                    continue
                alternatives = [alt for alt in candidates if isinstance(alt, dict)]
                if alternatives:
                    return self._union_types(
                        [
                            self._json_schema_to_python_type(alt, f'{model_name}_{index}')
                            for index, alt in enumerate(alternatives)
                        ]
                    )

        # "type" may be a list such as ["string", "null"]; convert each entry and
        # union them (a list cannot be used as a key in the mapping below)
        if isinstance(json_type, list):
            variants = [
                self._json_schema_to_python_type({**schema, 'type': single_type}, model_name)
                for single_type in json_type
                if single_type != 'null'
            ]
            if 'null' in json_type or not variants:
                variants.append(type(None))
            return self._union_types(variants)

        # Missing or malformed type falls back to string
        if not isinstance(json_type, str):
            json_type = 'string'

        # String enums (and enums without a declared type) are plain strings;
        # enums of other declared types fall through to that type's mapping
        if 'enum' in schema and json_type == 'string':
            return str

        # Handle objects with nested properties
        if json_type == 'object':
            properties = schema.get('properties', {})
            if not properties:
                # Object without properties - just return dict
                return dict

            nested_fields = self._build_model_fields(properties, set(schema.get('required', [])), model_name)
            try:
                # Create and return nested pydantic model
                return self._create_params_model(model_name, nested_fields)
            except Exception as e:
                logger.error(f'Failed to create nested model {model_name}: {e}')
                logger.debug(f'Fields: {nested_fields}')
                # Fallback to basic dict if model creation fails
                return dict

        # Handle arrays with specific item types
        if json_type == 'array':
            if 'items' in schema:
                # Get the item type recursively and return a properly typed list
                item_type = self._json_schema_to_python_type(schema['items'], f'{model_name}_item')
                return list[item_type]
            # Array without item type specification
            return list

        # Basic type mapping for the remaining (scalar) types
        type_mapping = {
            'string': str,
            'number': float,
            'integer': int,
            'boolean': bool,
            'array': list,
            'null': type(None),
        }
        base_type = type_mapping.get(json_type, str)

        # Handle nullable/optional types
        if schema.get('nullable', False) or json_type == 'null':
            return base_type | None

        return base_type

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()