# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
MCP Client classes for tool-calling environments.

This module provides client classes for interacting with MCP-enabled environments:
- MCPClientBase: Base class with shared tool discovery
- MCPToolClient: Client for tool-calling style (one tool per step)

These clients abstract away the MCP protocol details, providing a clean interface
for listing and calling tools on remote environments.

Example:
    >>> from openenv.core.mcp_client import MCPToolClient
    >>>
    >>> with MCPToolClient(base_url="http://localhost:8000") as env:
    ...     # Discover available tools
    ...     tools = env.list_tools()
    ...     print([t.name for t in tools])
    ...
    ...     # Call a tool
    ...     result = env.call_tool("echo_message", message="Hello!")
    ...     print(result)
"""

from typing import Any, Dict, List, Optional

from .client_types import StepResult, StateT
from .env_client import EnvClient
from .env_server.mcp_types import (
    CallToolAction,
    CallToolObservation,
    ListToolsAction,
    ListToolsObservation,
    Tool,
    ToolError,
)
from .env_server.types import Observation, State


class MCPClientBase(EnvClient[Any, Observation, State]):
    """
    Base class for MCP clients with tool discovery.

    This class provides the common `list_tools()` method for discovering
    available tools from an MCP-enabled environment. Subclasses implement
    specific interaction patterns (tool-calling or CodeAct).

    Attributes:
        _tools_cache: Cached list of tools (populated on first `list_tools()` call)
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        provider: Optional[Any] = None,
    ):
        """
        Initialize MCP client.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
            connect_timeout_s: Timeout for establishing WebSocket connection.
            message_timeout_s: Timeout for receiving responses to messages.
            provider: Optional container/runtime provider for lifecycle management.
        """
        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            provider=provider,
        )
        # None means "not fetched yet"; an empty list is a valid cached answer.
        self._tools_cache: Optional[List[Tool]] = None

    def list_tools(self, use_cache: bool = True) -> List[Tool]:
        """
        Discover available tools from the environment.

        Args:
            use_cache: If True, return cached tools if available.
                Set to False to force a fresh request.

        Returns:
            List of Tool objects with name, description, and input_schema.
            A new list is returned on every call, so callers may modify it
            without affecting the client's cache.

        Example:
            >>> tools = env.list_tools()
            >>> for tool in tools:
            ...     print(f"{tool.name}: {tool.description}")
        """
        if not use_cache or self._tools_cache is None:
            step_result = self.step(ListToolsAction())
            self._tools_cache = step_result.observation.tools
        # Hand out a shallow copy so that appending to / removing from the
        # returned list cannot corrupt the cached tool list.
        return list(self._tools_cache)

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """
        Convert an Action object to the JSON data expected by the env server.

        Args:
            action: A ListToolsAction, a CallToolAction, or any other action.

        Returns:
            A JSON-serialisable dict. MCP actions are tagged with a "type" key.
            Other actions are dumped via ``model_dump()`` when available, and
            otherwise wrapped as ``{"action": str(action)}``.
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}
        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments,
            }
        # For unknown actions, try to serialize as dict
        if hasattr(action, "model_dump"):
            return action.model_dump()
        return {"action": str(action)}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
        """
        Convert a JSON response from the env server to StepResult[Observation].

        The observation type is chosen by inspecting the keys of
        ``payload["observation"]``:
        - ``"tools"`` produces a ListToolsObservation;
        - ``"tool_name"`` produces a CallToolObservation;
        - anything else produces a generic Observation.

        Args:
            payload: Decoded JSON response. Top-level ``done`` and ``reward``
                keys are read alongside ``observation``.

        Returns:
            StepResult wrapping the parsed observation.
        """
        # Treat an explicit ``"observation": null`` the same as a missing one;
        # otherwise the ``in`` checks below would raise TypeError on None.
        obs_data = payload.get("observation") or {}
        done = payload.get("done", False)
        reward = payload.get("reward")
        # A null metadata field falls back to an empty dict, like a missing one.
        metadata = obs_data.get("metadata") or {}

        if "tools" in obs_data:
            tools = [
                Tool(
                    name=t.get("name", ""),
                    description=t.get("description") or "",
                    # Accept both ``input_schema`` and camelCase ``inputSchema``.
                    input_schema=t.get("input_schema", t.get("inputSchema")) or {},
                )
                for t in obs_data.get("tools") or []
            ]
            observation: Observation = ListToolsObservation(
                tools=tools,
                done=done,
                reward=reward,
                metadata=metadata,
            )
        elif "tool_name" in obs_data:
            error_data = obs_data.get("error")
            error = ToolError(**error_data) if error_data else None
            observation = CallToolObservation(
                tool_name=obs_data.get("tool_name", ""),
                result=obs_data.get("result"),
                error=error,
                done=done,
                reward=reward,
                metadata=metadata,
            )
        else:
            # Generic observation
            observation = Observation(
                done=done,
                reward=reward,
                metadata=metadata,
            )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """
        Convert a JSON response from the state endpoint to a State object.

        Only ``episode_id`` and ``step_count`` (default 0) are read from
        ``payload``; any other keys are ignored.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )


class MCPToolClient(MCPClientBase):
    """
    Client for tool-calling style MCP interactions.

    Each step invokes a single tool. Use this for traditional function-calling
    agent patterns where the agent decides which tool to call next.

    This client provides convenience methods for tool discovery and invocation:
    - `list_tools()`: Get all available tools with their schemas
    - `call_tool(name, **kwargs)`: Invoke a tool by name with arguments

    Example:
        >>> with MCPToolClient(base_url="http://localhost:8000") as env:
        ...     # Reset the environment
        ...     env.reset()
        ...
        ...     # Discover available tools
        ...     tools = env.list_tools()
        ...     print([t.name for t in tools])  # ['echo_message', 'echo_with_length']
        ...
        ...     # Call a tool directly
        ...     result = env.call_tool("echo_message", message="Hello!")
        ...     print(result)  # "Hello!"
        ...
        ...     # Or use the full action interface
        ...     from openenv.core.env_server.mcp_types import CallToolAction
        ...     step_result = env.step(CallToolAction(
        ...         tool_name="echo_with_length",
        ...         arguments={"message": "Test"}
        ...     ))
        ...     print(step_result.observation.result)
    """

    def call_tool(self, name: str, **kwargs: Any) -> Any:
        """
        Call a tool by name.

        This is a convenience method that creates a CallToolAction, executes it,
        and returns the result directly. For more control, use `step()` with
        a CallToolAction directly.

        Args:
            name: Name of the tool to invoke (must match a tool from `list_tools()`).
            **kwargs: Arguments to pass to the tool. Must match the tool's input_schema.

        Returns:
            The tool's result. The type depends on the tool being called.
            If the result carries a ``data`` attribute or ``"data"`` key, that
            value is returned instead. If the server responds with something
            other than a CallToolObservation, that observation is returned as-is.

        Raises:
            RuntimeError: If the server returns an error response.

        Example:
            >>> result = env.call_tool("add", a=5, b=3)
            >>> print(result)  # 8
            >>>
            >>> result = env.call_tool("greet", name="Claude")
            >>> print(result)  # "Hello, Claude!"
        """
        action = CallToolAction(tool_name=name, arguments=kwargs)
        obs = self.step(action).observation

        # Fallback for unexpected observation types
        if not isinstance(obs, CallToolObservation):
            return obs

        # Check for transport/framework errors
        if obs.error is not None:
            error_type = obs.error.error_type
            # Use the Enum's .value when present; if error_type is not an Enum
            # (e.g. a plain string), report it as-is rather than raising an
            # AttributeError that would hide the actual tool failure.
            error_type_name = getattr(error_type, "value", error_type)
            raise RuntimeError(
                f"Tool '{name}' failed: {obs.error.message} "
                f"(type: {error_type_name})"
            )

        tool_output = obs.result
        # Handle FastMCP CallToolResult objects
        # - As object: has .data attribute
        # - As dict (from JSON): has "data" key
        if hasattr(tool_output, "data"):
            return tool_output.data
        if isinstance(tool_output, dict) and "data" in tool_output:
            return tool_output["data"]
        return tool_output

    def get_tool(self, name: str) -> Optional[Tool]:
        """
        Get a specific tool by name.

        Uses `list_tools()` (and therefore its cache) to look the tool up.

        Args:
            name: Name of the tool to find.

        Returns:
            The first Tool object whose name matches, or None if none does.

        Example:
            >>> tool = env.get_tool("echo_message")
            >>> if tool:
            ...     print(tool.description)
            ...     print(tool.input_schema)
        """
        return next((tool for tool in self.list_tools() if tool.name == name), None)

    def has_tool(self, name: str) -> bool:
        """
        Check if a tool exists.

        Args:
            name: Name of the tool to check.

        Returns:
            True if the tool exists, False otherwise.
        """
        return self.get_tool(name) is not None