# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
MCP Environment base class for OpenEnv.

This module provides the MCPEnvironment base class that integrates FastMCP servers
with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery
and invocation through the step() API, following RFC 003.

Key features:
- Automatic routing of ListToolsAction and CallToolAction to MCP server
- Reserved tool name validation (reset, step, state, close are protected)
- Timeout handling for tool calls
- Proper error categorization (tool not found, execution errors, timeouts)
- Mode-aware tool registration (production vs simulation)
- Code mode support via get_callables() and execute_code()

Usage:
    from fastmcp import FastMCP
    from openenv.core.env_server.mcp_environment import MCPEnvironment

    class MyMCPEnv(MCPEnvironment):
        def __init__(self):
            mcp = FastMCP("my-server")

            # Initialize the base class first: self.tool() relies on state
            # (the server reference and the mode registries) that
            # MCPEnvironment.__init__ creates.
            super().__init__(mcp)

            # Register mode-specific tools
            @self.tool(mode="production")
            def my_tool(arg: str) -> str:
                return f"Production: {arg}"

            @self.tool(mode="simulation")
            def my_tool(arg: str) -> str:
                return f"Simulation: {arg}"

        def reset(self, seed=None, episode_id=None, **kwargs):
            # Reset logic here
            ...

        def _step_impl(self, action):
            # Handle non-MCP actions
            ...

        @property
        def state(self):
            # Return current state
            ...
"""

import asyncio
import inspect
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Optional

from fastmcp import Client
from fastmcp.client.client import CallToolResult
from mcp.types import TextContent

from ..utils import run_async_safely
from .interfaces import Environment
from .mcp_types import (
    CallToolAction,
    CallToolObservation,
    ListToolsAction,
    ListToolsObservation,
    RESERVED_TOOL_NAMES,
    Tool,
    ToolError,
    ToolErrorType,
)
from .types import Action, Observation

# Default timeout for MCP tool calls in seconds
MCP_TOOL_CALL_TIMEOUT = 30.0

# Valid modes for tool registration
VALID_MODES = {"production", "simulation"}


def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
    """
    Get tools from a FastMCP server, compatible with both 2.x and 3.x.

    Args:
        mcp_server: The FastMCP server instance to inspect.

    Returns:
        Dictionary mapping tool names to tool objects. Empty if the server
        exposes neither ``get_tools`` nor ``list_tools``.
    """
    # FastMCP 2.x: get_tools() returns dict {name: Tool}
    if hasattr(mcp_server, "get_tools"):
        result = run_async_safely(mcp_server.get_tools())
        if isinstance(result, dict):
            return result
    # FastMCP 3.x: list_tools() returns list of Tool objects
    if hasattr(mcp_server, "list_tools"):
        tools_list = run_async_safely(mcp_server.list_tools())
        return {t.name: t for t in tools_list}
    return {}


def _infer_input_schema(func: Callable) -> Dict[str, Any]:
    """Build a minimal JSON-schema-like dict from ``func``'s signature.

    Only ``int``, ``float`` and ``bool`` annotations (or their string forms)
    are mapped to specific JSON types; everything else is reported as
    ``"string"``.
    """
    schema: Dict[str, Any] = {
        "type": "object",
        "properties": {},
        "required": [],
    }
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    for param_name, param in inspect.signature(func).parameters.items():
        # *args / **kwargs are not named arguments a caller can supply, so
        # they must not be advertised as (required) properties.
        if param.kind in variadic_kinds:
            continue

        param_type = param.annotation
        json_type = "string"  # default
        if param_type in (int, "int"):
            json_type = "integer"
        elif param_type in (float, "float"):
            json_type = "number"
        elif param_type in (bool, "bool"):
            json_type = "boolean"
        schema["properties"][param_name] = {"type": json_type}

        # Identity check: `empty` is a sentinel, and `==` could be hijacked
        # by a default value with a custom __eq__.
        if param.default is inspect.Parameter.empty:
            schema["required"].append(param_name)
    return schema


def _classify_error_message(error_message: str) -> ToolErrorType:
    """Map an exception message from the MCP client to a ToolErrorType."""
    lowered = error_message.lower()
    if "not found" in lowered or "unknown tool" in lowered:
        return ToolErrorType.TOOL_NOT_FOUND
    if "invalid" in lowered or "argument" in lowered:
        return ToolErrorType.INVALID_ARGS
    return ToolErrorType.EXECUTION_ERROR


class MCPEnvironment(Environment):
    """
    Base class for environments that expose tools via MCP (Model Context Protocol).

    MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing
    agents to discover and invoke MCP tools through the standard step() interface.

    The class automatically handles:
    - ListToolsAction: Returns available tools from the MCP server
    - CallToolAction: Invokes a specific tool with arguments

    All other actions are delegated to the abstract _step_impl() method,
    which subclasses must implement.

    Args:
        mcp_server: A FastMCP server instance containing tool definitions.
            The server's tools will be validated against reserved names.
        transform: Optional transform to apply to observations (inherited from Environment).

    Raises:
        ValueError: If any tool in the MCP server uses a reserved name
            (reset, step, state, close).

    Example:
        >>> from fastmcp import FastMCP
        >>> mcp = FastMCP("calculator")
        >>> @mcp.tool()
        ... def add(a: int, b: int) -> int:
        ...     return a + b
        >>> env = MyMCPEnvironment(mcp)
        >>> obs = env.step(ListToolsAction())
        >>> obs.tools[0].name
        'add'
    """

    def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None:
        """
        Initialize the MCP environment.

        Args:
            mcp_server: A FastMCP server instance with tool definitions.
            transform: Optional transform to apply to observations.

        Raises:
            ValueError: If any tool uses a reserved name (reset, step, state, close).
        """
        super().__init__(transform=transform)

        # Validate tool names before storing
        self._validate_tool_names(mcp_server)

        self.mcp_server = mcp_server
        self.mcp_client = Client(mcp_server)

        # Track mode-specific tools: {tool_name: {mode: func}}
        # mode can be "production", "simulation", or None (available in all modes)
        self._mode_tools: Dict[str, Dict[Optional[str], Callable]] = defaultdict(dict)

        # Track tool schemas for list_tools: {tool_name: {mode: schema}}
        self._mode_tool_schemas: Dict[str, Dict[Optional[str], Dict[str, Any]]] = (
            defaultdict(dict)
        )

    @property
    def supports_code_mode(self) -> bool:
        """Check if this environment supports code mode (execute_code)."""
        return True

    def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]:
        """
        Get tools from a FastMCP server, compatible with both 2.x and 3.x.

        Returns:
            Dictionary mapping tool names to tool objects.
        """
        return get_server_tools(mcp_server)

    def get_callables(self) -> Dict[str, Callable]:
        """
        Get callable functions for code mode.

        Returns tool functions as direct Python callables, enabling code mode
        where agents write Python code that calls tools directly (no JSON-RPC
        overhead). Mode-specific tools are filtered by the current mode.

        Returns:
            Dictionary mapping tool names to callables.
        """
        callables: Dict[str, Callable] = {}
        current_mode = getattr(self, "_mode", None)

        # Extract callables from FastMCP server using public API
        for tool_name, tool in self._get_server_tools(self.mcp_server).items():
            if hasattr(tool, "fn") and callable(tool.fn):
                callables[tool_name] = tool.fn

        # Add mode-specific tools available in current mode
        for tool_name, mode_funcs in self._mode_tools.items():
            if None in mode_funcs:
                # Tool available in all modes (already in FastMCP if registered there)
                if tool_name not in callables:
                    callables[tool_name] = mode_funcs[None]
            elif current_mode in mode_funcs:
                # Tool available in current mode only
                callables[tool_name] = mode_funcs[current_mode]

        return callables

    def execute_code(self, code: str) -> Observation:
        """
        Execute Python code with tools available as callables.

        This enables the CodeAct pattern where agents write Python code
        that calls tools directly as functions, avoiding JSON-RPC overhead.

        Args:
            code: Python code to execute. Tools are available as functions
                in the execution namespace. Set a variable named 'result'
                to capture the return value.

        Returns:
            Observation with result in metadata["result"] or error in metadata["error"].
        """
        tools = self.get_callables()
        # Run in ONE namespace. Passing separate globals/locals to exec()
        # makes the code behave like a class body, so functions defined by
        # the agent could not see imports or helpers defined alongside them.
        namespace: Dict[str, Any] = dict(tools)

        try:
            # SECURITY NOTE(review): exec() runs arbitrary code with this
            # process's privileges; callers must sandbox untrusted input.
            exec(code, namespace)
            result = namespace.get("result")
            # A tool that happens to be named "result" is not a value the
            # executed code produced.
            if result is not None and result is tools.get("result"):
                result = None
            return Observation(done=False, reward=0.0, metadata={"result": result})
        except SyntaxError as e:
            return Observation(
                done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"}
            )
        except Exception as e:
            return Observation(done=False, reward=0.0, metadata={"error": str(e)})

    def _validate_tool_names(self, mcp_server: Any) -> None:
        """
        Validate that no tools use reserved names.

        Reserved names (reset, step, state, close) are protected to maintain
        the dual API boundary between infrastructure and agent APIs.

        Args:
            mcp_server: The FastMCP server to validate.

        Raises:
            ValueError: If any tool uses a reserved name.
        """
        tools_dict = self._get_server_tools(mcp_server)
        if tools_dict:
            tool_names = set(tools_dict.keys())
            conflicts = tool_names & RESERVED_TOOL_NAMES
            if conflicts:
                raise ValueError(
                    f"MCP tools cannot use reserved names: {sorted(conflicts)}. "
                    f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
                )

    def tool(self, mode: Optional[str] = None) -> Callable:
        """
        Decorator for registering mode-aware tools.

        Must be used after ``MCPEnvironment.__init__`` has run, because it
        reads ``self.mcp_server`` and the mode registries created there.

        Args:
            mode: Optional mode for the tool ("production" or "simulation").
                If None, tool is available in all modes.

        Returns:
            A decorator function for registering tools.

        Raises:
            ValueError: If mode is not None, "production", or "simulation",
                or (when the decorator is applied) if the function's name is
                a reserved tool name.
        """
        if mode is not None and mode not in VALID_MODES:
            raise ValueError(
                f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None."
            )

        def decorator(func: Callable) -> Callable:
            tool_name = func.__name__

            # Validate tool name is not reserved
            if tool_name in RESERVED_TOOL_NAMES:
                raise ValueError(
                    f"Tool name '{tool_name}' is reserved and cannot be used. "
                    f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
                )

            # If mode is None, register with FastMCP as usual
            if mode is None:
                decorated_func = self.mcp_server.tool()(func)
                self._mode_tools[tool_name][None] = func
                return decorated_func

            # For mode-specific tools, don't register with FastMCP
            # Instead, track them ourselves
            self._mode_tools[tool_name][mode] = func

            # Store the schema for this mode-specific tool
            self._mode_tool_schemas[tool_name][mode] = {
                "name": tool_name,
                "description": func.__doc__ or "",
                "input_schema": _infer_input_schema(func),
            }

            return func

        return decorator

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Execute an action in the environment.

        This method routes MCP-specific actions (ListToolsAction, CallToolAction)
        to the appropriate handlers, while delegating all other actions to
        the subclass's _step_impl() method.

        Args:
            action: The action to execute. Can be:
                - ListToolsAction: Returns available MCP tools
                - CallToolAction: Invokes a specific MCP tool
                - Any other Action: Delegated to _step_impl()
            timeout_s: Optional timeout in seconds for the action.
                Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for tool calls.
            **kwargs: Additional arguments passed to _step_impl().

        Returns:
            Observation appropriate to the action type:
                - ListToolsObservation for ListToolsAction
                - CallToolObservation for CallToolAction
                - Subclass-defined Observation for other actions
        """
        if isinstance(action, ListToolsAction):
            return self._handle_list_tools()
        elif isinstance(action, CallToolAction):
            return self._handle_call_tool(action, timeout_s=timeout_s)
        else:
            return self._step_impl(action, timeout_s=timeout_s, **kwargs)

    def _handle_list_tools(self) -> ListToolsObservation:
        """
        Handle a ListToolsAction by querying the MCP server.

        Returns:
            ListToolsObservation containing all available tools with their
            names, descriptions, and input schemas, filtered by current mode.
            On failure, an empty tool list with the error in ``metadata``.
        """
        try:
            # Get current mode
            current_mode = getattr(self, "_mode", None)

            # Start with tools from FastMCP server (mode=None tools)
            tools_result = run_async_safely(self._async_list_tools())

            # Add FastMCP tools that are not shadowed by a mode-specific tool
            tools = [
                Tool(
                    name=tool.name,
                    description=tool.description or "",
                    input_schema=tool.inputSchema
                    if hasattr(tool, "inputSchema")
                    else {},
                )
                for tool in tools_result
                if tool.name not in self._mode_tool_schemas
            ]

            # Add mode-specific tools available in current mode
            for mode_schemas in self._mode_tool_schemas.values():
                if None in mode_schemas:
                    # Tool available in all modes
                    schema = mode_schemas[None]
                elif current_mode in mode_schemas:
                    # Tool available in current mode
                    schema = mode_schemas[current_mode]
                else:
                    continue
                tools.append(
                    Tool(
                        name=schema["name"],
                        description=schema["description"],
                        input_schema=schema["input_schema"],
                    )
                )

            return ListToolsObservation(tools=tools)

        except Exception as e:
            # Return an observation with error in metadata
            return ListToolsObservation(
                tools=[],
                metadata={
                    "error": str(e),
                    "error_type": "list_tools_failed",
                },
            )

    async def _async_list_tools(self) -> list:
        """
        Async helper to list tools from the MCP client.

        Returns:
            List of tool objects from the MCP server.
        """
        async with self.mcp_client:
            return await self.mcp_client.list_tools()

    def _handle_call_tool(
        self,
        action: CallToolAction,
        timeout_s: Optional[float] = None,
    ) -> CallToolObservation:
        """
        Handle a CallToolAction by invoking the specified tool.

        Tools registered through ``self.tool()`` are called directly; every
        other tool name is forwarded to the FastMCP client.

        Args:
            action: The CallToolAction containing tool_name and arguments.
            timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
                Enforced for FastMCP calls and for async registered tools;
                a synchronous registered tool runs to completion.

        Returns:
            CallToolObservation with the tool's result or an error.
        """
        timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
        tool_name = action.tool_name

        # Check if this is a tool registered through self.tool()
        if tool_name in self._mode_tools:
            current_mode = getattr(self, "_mode", None)
            mode_info = self._mode_tools[tool_name]

            # Tool is available if:
            # 1. It has a None mode (available in all modes), OR
            # 2. It has an implementation for the current mode
            if None in mode_info:
                func = mode_info[None]
            elif current_mode in mode_info:
                func = mode_info[current_mode]
            else:
                # Tool not available in current mode
                return CallToolObservation(
                    tool_name=tool_name,
                    result=None,
                    error=ToolError(
                        error_type=ToolErrorType.TOOL_NOT_FOUND,
                        message=f"Tool '{tool_name}' not available in {current_mode} mode",
                    ),
                )

            return self._call_registered_tool(
                tool_name, func, action.arguments, timeout
            )

        # Not a registered tool, use FastMCP
        try:
            # Run the async call_tool with timeout
            # Use run_async_safely to handle both sync and async contexts
            result = run_async_safely(
                asyncio.wait_for(
                    self._async_call_tool(action.tool_name, action.arguments),
                    timeout=timeout,
                )
            )
            return CallToolObservation(
                tool_name=action.tool_name,
                result=result,
            )

        except asyncio.TimeoutError:
            return CallToolObservation(
                tool_name=action.tool_name,
                result=None,
                error=ToolError(
                    error_type=ToolErrorType.TIMEOUT,
                    message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
                ),
            )

        except Exception as e:
            error_message = str(e)
            return CallToolObservation(
                tool_name=action.tool_name,
                result=None,
                error=ToolError(
                    error_type=_classify_error_message(error_message),
                    message=error_message,
                ),
            )

    def _call_registered_tool(
        self,
        tool_name: str,
        func: Callable,
        arguments: Dict[str, Any],
        timeout: float,
    ) -> CallToolObservation:
        """Invoke a tool registered via ``self.tool()`` and wrap its outcome."""
        try:
            if inspect.iscoroutinefunction(func):
                # Honor the timeout for async tools, as the FastMCP path does.
                result = run_async_safely(
                    asyncio.wait_for(func(**arguments), timeout=timeout)
                )
            else:
                result = func(**arguments)

            # Wrap result in CallToolResult format to match FastMCP behavior
            return CallToolObservation(
                tool_name=tool_name,
                result=CallToolResult(
                    content=[TextContent(type="text", text=str(result))],
                    structured_content={"result": result},
                    meta=None,
                    data=result,
                    is_error=False,
                ),
            )
        except asyncio.TimeoutError:
            return CallToolObservation(
                tool_name=tool_name,
                result=None,
                error=ToolError(
                    error_type=ToolErrorType.TIMEOUT,
                    message=f"Tool '{tool_name}' timed out after {timeout} seconds",
                ),
            )
        except Exception as e:
            return CallToolObservation(
                tool_name=tool_name,
                result=None,
                error=ToolError(
                    error_type=ToolErrorType.EXECUTION_ERROR,
                    message=str(e),
                ),
            )

    async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
        """
        Async helper to call a tool on the MCP server.

        Args:
            tool_name: Name of the tool to invoke.
            arguments: Dictionary of arguments to pass to the tool.

        Returns:
            The result from the tool execution.
        """
        async with self.mcp_client:
            return await self.mcp_client.call_tool(tool_name, arguments)

    @abstractmethod
    def _step_impl(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """
        Handle non-MCP actions in the environment.

        Subclasses must implement this method to handle any actions that are
        not ListToolsAction or CallToolAction. This is where environment-specific
        action processing should occur.

        Args:
            action: The action to execute (guaranteed not to be an MCP action).
            timeout_s: Optional timeout in seconds.
            **kwargs: Additional arguments.

        Returns:
            An Observation appropriate for the action.
        """
        pass

    def close(self) -> None:
        """
        Clean up resources used by the environment.

        This method cleans up the MCP client and any other resources.
        Subclasses should call super().close() if they override this method.
        """
        # The MCP client uses async context manager, so cleanup happens
        # automatically when the context exits. We just clear references.
        self.mcp_client = None
        self.mcp_server = None