# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
MCP Client classes for tool-calling environments.

This module provides async client classes for interacting with MCP-enabled
environments:

- MCPClientBase: Base class with shared tool discovery
- MCPToolClient: Client for tool-calling style (one tool per step)

These clients abstract away the MCP protocol details, providing a clean
interface for listing and calling tools on remote environments. All clients
are async by default.

Architecture Overview::

    ┌─────────────────────────────────────────────────────────┐
    │                    HTTPEnvServer                        │
    ├─────────────────────────────────────────────────────────┤
    │  Simulation Mode (default):                             │
    │    /ws     → OpenEnv protocol (reset/step/state)        │
    │    /mcp    → MCP JSON-RPC (tools/list, tools/call)      │
    │    /reset, /step, /state → HTTP endpoints               │
    ├─────────────────────────────────────────────────────────┤
    │  Production Mode (use_production_mode=True):            │
    │    /mcp    → MCP JSON-RPC (tools/list, tools/call)      │
    │    Bypasses step() for direct tool access               │
    └─────────────────────────────────────────────────────────┘

Client Usage:
    MCPToolClient (default)    → /ws  (step-based, with rewards)
    MCPToolClient (production) → /mcp (direct tool access, no rewards)

Example (async):
    >>> from openenv.core.mcp_client import MCPToolClient
    >>>
    >>> async with MCPToolClient(base_url="http://localhost:8000") as env:
    ...     # Discover available tools
    ...     tools = await env.list_tools()
    ...     print([t.name for t in tools])
    ...
    ...     # Call a tool
    ...     result = await env.call_tool("echo_message", message="Hello!")
    ...     print(result)

Example (sync wrapper):
    >>> env = MCPToolClient(base_url="http://localhost:8000").sync()
    >>> with env:
    ...     tools = env.list_tools()
    ...     result = env.call_tool("echo_message", message="Hello!")
"""

import asyncio
import functools
import logging
from typing import Any, Dict, List, Optional

from .client_types import StepResult
from .env_client import EnvClient
from .env_server.mcp_types import (
    CallToolAction,
    CallToolObservation,
    ListToolsAction,
    ListToolsObservation,
    Tool,
    ToolError,
)
from .env_server.types import Observation, State

logger = logging.getLogger(__name__)


def _tool_from_dict(raw: Dict[str, Any]) -> Tool:
    """Build a Tool from a JSON dict, accepting either schema key spelling."""
    return Tool(
        name=raw.get("name", ""),
        description=raw.get("description", ""),
        # Both the snake_case and the camelCase spelling are accepted.
        input_schema=raw.get("input_schema", raw.get("inputSchema", {})),
    )


def _mcp_http_url(ws_url: str) -> str:
    """Derive the HTTP ``/mcp`` endpoint URL from a WebSocket URL."""
    # Only rewrite the scheme prefix; never touch text later in the URL.
    for ws_scheme, http_scheme in (("wss://", "https://"), ("ws://", "http://")):
        if ws_url.startswith(ws_scheme):
            ws_url = http_scheme + ws_url[len(ws_scheme):]
            break

    url = ws_url.rstrip("/")
    # Drop an exact "/ws" path suffix. str.rstrip("/ws") would strip any
    # trailing run of the characters "/", "w" and "s", which corrupts hosts
    # such as "http://aws/ws".
    if url.endswith("/ws"):
        url = url[: -len("/ws")]
    return url.rstrip("/") + "/mcp"


class MCPClientBase(EnvClient[Any, Observation, State]):
    """
    Base class for MCP clients with tool discovery.

    This class provides the common `list_tools()` method for discovering
    available tools from an MCP-enabled environment. Subclasses implement
    specific interaction patterns (tool-calling or CodeAct).

    Attributes:
        _tools_cache: Cached list of tools (populated on first `list_tools()`
            call).
        use_production_mode: When True, `list_tools()` queries the `/mcp`
            JSON-RPC endpoint over HTTP instead of issuing a step. Initialised
            to False.
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        provider: Optional[Any] = None,
        mode: Optional[str] = None,
    ):
        """
        Initialize MCP client.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
            connect_timeout_s: Timeout for establishing WebSocket connection.
            message_timeout_s: Timeout for receiving responses to messages.
                Also used as the timeout of the HTTP `tools/list` request.
            provider: Optional container/runtime provider for lifecycle
                management.
            mode: Communication mode. Must be 'production' for MCP clients.
                Defaults to 'production'.

        Raises:
            ValueError: If `mode` is given and is not 'production'
                (compared case-insensitively).
        """
        # MCPClientBase defaults to production mode, but allow override for
        # validation.
        if mode is None:
            mode = "production"

        # Validate that mode is production.
        mode_lower = mode.lower()
        if mode_lower != "production":
            raise ValueError(
                f"MCPToolClient only supports 'production' mode, got '{mode}'. "
                f"Use GenericEnvClient for simulation mode."
            )

        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            provider=provider,
            mode=mode,
        )
        self._tools_cache: Optional[List[Tool]] = None
        # NOTE(review): this flag starts as False even though `mode` is
        # validated to be 'production' above; callers must set it to True to
        # get the HTTP path in `list_tools()`. Confirm this is intended.
        self.use_production_mode = False
        # Bound for the HTTP tools/list request so it cannot hang forever.
        self._mcp_http_timeout_s = message_timeout_s

    async def list_tools(self, use_cache: bool = True) -> List[Tool]:
        """
        Discover available tools from the environment.

        Args:
            use_cache: If True, return cached tools if available. Set to False
                to force a fresh request.

        Returns:
            List of Tool objects with name, description, and input_schema.
            When `use_production_mode` is True and the HTTP request fails or
            the response carries no tool list, an empty list is returned (and
            nothing is cached).

        Example:
            >>> tools = await env.list_tools()
            >>> for tool in tools:
            ...     print(f"{tool.name}: {tool.description}")
        """
        if use_cache and self._tools_cache is not None:
            return self._tools_cache

        # Use production mode HTTP endpoint if enabled.
        if self.use_production_mode:
            return await self._list_tools_via_http()

        result = await self.step(ListToolsAction())
        self._tools_cache = result.observation.tools
        return self._tools_cache

    async def _list_tools_via_http(self) -> List[Tool]:
        """Fetch tools with a JSON-RPC `tools/list` POST to the /mcp endpoint."""
        import requests

        request = functools.partial(
            requests.post,
            _mcp_http_url(self._ws_url),
            json={
                "jsonrpc": "2.0",
                "method": "tools/list",
                "params": {},
                "id": 1,
            },
            timeout=self._mcp_http_timeout_s,
        )

        try:
            # requests is blocking; run it in the default executor so the
            # event loop keeps servicing other tasks while we wait.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, request)
            data = response.json()
            if "result" in data and "tools" in data["result"]:
                tools = [_tool_from_dict(t) for t in data["result"]["tools"]]
                self._tools_cache = tools
                return tools
        except Exception:
            # Deliberately best-effort: a failed discovery yields an empty
            # list, but the cause is logged instead of silently discarded.
            logger.warning("MCP tools/list HTTP request failed", exc_info=True)

        return []

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """Convert an Action object to the JSON data expected by the env server."""
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}

        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments,
            }

        # For unknown actions, try to serialize as dict.
        if hasattr(action, "model_dump"):
            return action.model_dump()
        return {"action": str(action)}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
        """Convert a JSON response from the env server to StepResult[Observation]."""
        # `or {}` also covers an explicit JSON null, which `.get(key, {})`
        # would pass through as None and crash the membership tests below.
        obs_data = payload.get("observation") or {}
        done = payload.get("done", False)
        reward = payload.get("reward")
        metadata = obs_data.get("metadata", {})

        observation: Observation
        if "tools" in obs_data:
            # ListToolsObservation
            observation = ListToolsObservation(
                tools=[_tool_from_dict(t) for t in obs_data.get("tools") or []],
                done=done,
                reward=reward,
                metadata=metadata,
            )
        elif "tool_name" in obs_data:
            # CallToolObservation
            error = None
            if obs_data.get("error"):
                error = ToolError(**obs_data["error"])
            observation = CallToolObservation(
                tool_name=obs_data.get("tool_name", ""),
                result=obs_data.get("result"),
                error=error,
                done=done,
                reward=reward,
                metadata=metadata,
            )
        else:
            # Generic observation
            observation = Observation(
                done=done,
                reward=reward,
                metadata=metadata,
            )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """Convert a JSON response from the state endpoint to a State object."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )


class MCPToolClient(MCPClientBase):
    """
    Async client for tool-calling style MCP interactions.

    Each step invokes a single tool. Use this for traditional function-calling
    agent patterns where the agent decides which tool to call next.

    This client provides convenience methods for tool discovery and invocation:

    - `list_tools()`: Get all available tools with their schemas
    - `call_tool(name, **kwargs)`: Invoke a tool by name with arguments

    Example (async):
        >>> async with MCPToolClient(base_url="http://localhost:8000") as env:
        ...     # Reset the environment
        ...     await env.reset()
        ...
        ...     # Discover available tools
        ...     tools = await env.list_tools()
        ...     print([t.name for t in tools])  # ['echo_message', 'echo_with_length']
        ...
        ...     # Call a tool directly
        ...     result = await env.call_tool("echo_message", message="Hello!")
        ...     print(result)  # "Hello!"
        ...
        ...     # Or use the full action interface
        ...     from openenv.core.env_server.mcp_types import CallToolAction
        ...     step_result = await env.step(CallToolAction(
        ...         tool_name="echo_with_length",
        ...         arguments={"message": "Test"}
        ...     ))
        ...     print(step_result.observation.result)

    Example (sync wrapper):
        >>> env = MCPToolClient(base_url="http://localhost:8000").sync()
        >>> with env:
        ...     tools = env.list_tools()
        ...     result = env.call_tool("echo_message", message="Hello!")
    """

    async def call_tool(self, name: str, **kwargs: Any) -> Any:
        """
        Call a tool by name.

        This is a convenience method that creates a CallToolAction, executes
        it, and returns the result directly. For more control, use `step()`
        with a CallToolAction directly.

        Args:
            name: Name of the tool to invoke (must match a tool from
                `list_tools()`).
            **kwargs: Arguments to pass to the tool. Must match the tool's
                input_schema.

        Returns:
            The tool's result. The type depends on the tool being called. If
            the result exposes a `data` attribute or `"data"` key, that inner
            value is returned. If the observation is not a
            CallToolObservation, the observation itself is returned.

        Raises:
            RuntimeError: If the server returns an error response.

        Example:
            >>> result = await env.call_tool("add", a=5, b=3)
            >>> print(result)  # 8
            >>>
            >>> result = await env.call_tool("greet", name="Claude")
            >>> print(result)  # "Hello, Claude!"
        """
        step_result = await self.step(
            CallToolAction(tool_name=name, arguments=kwargs)
        )
        obs = step_result.observation

        if not isinstance(obs, CallToolObservation):
            # Fallback for unexpected observation types.
            return obs

        # Check for transport/framework errors.
        if obs.error is not None:
            raise RuntimeError(
                f"Tool '{name}' failed: {obs.error.message} "
                f"(type: {obs.error.error_type.value})"
            )

        tool_result = obs.result
        # Handle FastMCP CallToolResult objects:
        # - As object: has .data attribute
        # - As dict (from JSON): has "data" key
        if hasattr(tool_result, "data"):
            return tool_result.data
        if isinstance(tool_result, dict) and "data" in tool_result:
            return tool_result["data"]
        return tool_result

    async def get_tool(self, name: str) -> Optional[Tool]:
        """
        Get a specific tool by name.

        Args:
            name: Name of the tool to find.

        Returns:
            The Tool object if found, None otherwise.

        Example:
            >>> tool = await env.get_tool("echo_message")
            >>> if tool:
            ...     print(tool.description)
            ...     print(tool.input_schema)
        """
        tools = await self.list_tools()
        return next((tool for tool in tools if tool.name == name), None)

    async def has_tool(self, name: str) -> bool:
        """
        Check if a tool exists.

        Args:
            name: Name of the tool to check.

        Returns:
            True if the tool exists, False otherwise.
        """
        return await self.get_tool(name) is not None