Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| MCP Environment base class for OpenEnv. | |
| This module provides the MCPEnvironment base class that integrates FastMCP servers | |
| with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery | |
| and invocation through the step() API, following RFC 003. | |
| Key features: | |
| - Automatic routing of ListToolsAction and CallToolAction to MCP server | |
| - Reserved tool name validation (reset, step, state, close are protected) | |
| - Timeout handling for tool calls | |
| - Proper error categorization (tool not found, execution errors, timeouts) | |
| - Mode-aware tool registration (production vs simulation) | |
| - Code mode support via get_callables() and execute_code() | |
| Usage: | |
| from fastmcp import FastMCP | |
| from openenv.core.env_server.mcp_environment import MCPEnvironment | |
| class MyMCPEnv(MCPEnvironment): | |
| def __init__(self): | |
| mcp = FastMCP("my-server") | |
| # Initialize the base class first: self.tool() relies on state set up by __init__ | |
| super().__init__(mcp) | |
| # Register mode-specific tools | |
| @self.tool(mode="production") | |
| def my_tool(arg: str) -> str: | |
| return f"Production: {arg}" | |
| @self.tool(mode="simulation") | |
| def my_tool(arg: str) -> str: | |
| return f"Simulation: {arg}" | |
| def reset(self, seed=None, episode_id=None, **kwargs): | |
| # Reset logic here | |
| ... | |
| def _step_impl(self, action): | |
| # Handle non-MCP actions | |
| ... | |
| @property | |
| def state(self): | |
| # Return current state | |
| ... | |
| """ | |
| import asyncio | |
| import inspect | |
| from abc import abstractmethod | |
| from collections import defaultdict | |
| from typing import Any, Callable, Dict, Optional | |
| from fastmcp import Client | |
| from fastmcp.client.client import CallToolResult | |
| from mcp.types import TextContent | |
| from ..utils import run_async_safely | |
| from .interfaces import Environment | |
| from .mcp_types import ( | |
| CallToolAction, | |
| CallToolObservation, | |
| ListToolsAction, | |
| ListToolsObservation, | |
| RESERVED_TOOL_NAMES, | |
| Tool, | |
| ToolError, | |
| ToolErrorType, | |
| ) | |
| from .types import Action, Observation | |
| # Default timeout, in seconds, for MCP tool calls. _handle_call_tool() falls | |
| # back to this value when the caller passes no timeout_s. | |
| MCP_TOOL_CALL_TIMEOUT = 30.0 | |
| # Mode names accepted by MCPEnvironment.tool(); passing mode=None there | |
| # instead registers a tool that is available in every mode. | |
| VALID_MODES = {"production", "simulation"} | |
def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
    """
    Collect the tools registered on a FastMCP server, keyed by tool name.

    Two server APIs are supported:
    - FastMCP 2.x: ``get_tools()`` yields a ``{name: tool}`` dict, returned as-is.
    - FastMCP 3.x: ``list_tools()`` yields a list of tool objects, which is
      re-keyed by each tool's ``name`` attribute.

    Args:
        mcp_server: The server object to inspect.

    Returns:
        Dictionary mapping tool names to tool objects; empty when the server
        exposes neither API.
    """
    if hasattr(mcp_server, "get_tools"):
        tools_by_name = run_async_safely(mcp_server.get_tools())
        # Only trust the 2.x answer when it really is a dict; any other
        # return shape falls through to the list-based API below.
        if isinstance(tools_by_name, dict):
            return tools_by_name

    if not hasattr(mcp_server, "list_tools"):
        return {}

    listed = run_async_safely(mcp_server.list_tools())
    return {entry.name: entry for entry in listed}
| class MCPEnvironment(Environment): | |
| """ | |
| Base class for environments that expose tools via MCP (Model Context Protocol). | |
| MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing | |
| agents to discover and invoke MCP tools through the standard step() interface. | |
| The class automatically handles: | |
| - ListToolsAction: Returns available tools from the MCP server | |
| - CallToolAction: Invokes a specific tool with arguments | |
| All other actions are delegated to the abstract _step_impl() method, | |
| which subclasses must implement. | |
| Args: | |
| mcp_server: A FastMCP server instance containing tool definitions. | |
| The server's tools will be validated against reserved names. | |
| transform: Optional transform to apply to observations (inherited from Environment). | |
| Raises: | |
| ValueError: If any tool in the MCP server uses a reserved name | |
| (reset, step, state, close). | |
| Example: | |
| >>> from fastmcp import FastMCP | |
| >>> mcp = FastMCP("calculator") | |
| >>> @mcp.tool() | |
| ... def add(a: int, b: int) -> int: | |
| ... return a + b | |
| >>> env = MyMCPEnvironment(mcp) | |
| >>> obs = env.step(ListToolsAction()) | |
| >>> obs.tools[0].name | |
| 'add' | |
| """ | |
def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None:
    """
    Set up the environment around an existing FastMCP server.

    Args:
        mcp_server: FastMCP server instance whose tools this environment exposes.
        transform: Optional observation transform, forwarded to the base class.

    Raises:
        ValueError: If the server already registers a tool with a reserved
            name (reset, step, state, close).
    """
    super().__init__(transform=transform)

    # Fail fast: the server is checked for reserved tool names before any
    # reference to it is kept on the instance.
    self._validate_tool_names(mcp_server)

    self.mcp_server = mcp_server
    self.mcp_client = Client(mcp_server)

    # tool name -> {mode -> function}. The mode key is "production",
    # "simulation", or None for a tool available in all modes.
    self._mode_tools: Dict[str, Dict[Optional[str], Callable]] = defaultdict(dict)

    # tool name -> {mode -> schema dict}, consulted when listing tools.
    self._mode_tool_schemas: Dict[str, Dict[Optional[str], Dict[str, Any]]] = (
        defaultdict(dict)
    )
def supports_code_mode(self) -> bool:
    """Report whether execute_code() is available; always True for this class."""
    return True
def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]:
    """
    Method-level entry point to the module function ``get_server_tools``.

    Args:
        mcp_server: The FastMCP server to inspect (2.x or 3.x).

    Returns:
        Dictionary mapping tool names to tool objects.
    """
    server_tools = get_server_tools(mcp_server)
    return server_tools
def get_callables(self) -> Dict[str, Callable]:
    """
    Build the name -> function table used by code mode.

    The table starts with every FastMCP server tool that exposes a callable
    ``fn`` attribute, then is merged with the tools registered through
    ``self.tool()``:
    - a mode-agnostic registration (mode ``None``) only fills a name the
      server did not already provide;
    - otherwise a registration for the current mode replaces whatever is
      stored under that name.

    Returns:
        Dictionary mapping tool names to plain Python callables.
    """
    # Environments that never set ``_mode`` are treated as having mode None.
    active_mode = getattr(self, "_mode", None)

    exposed: Dict[str, Callable] = {
        name: entry.fn
        for name, entry in self._get_server_tools(self.mcp_server).items()
        if hasattr(entry, "fn") and callable(entry.fn)
    }

    for name, by_mode in self._mode_tools.items():
        if None in by_mode:
            # Available everywhere; keep the server's callable if it has one.
            exposed.setdefault(name, by_mode[None])
        elif active_mode in by_mode:
            exposed[name] = by_mode[active_mode]

    return exposed
def execute_code(self, code: str) -> Observation:
    """
    Execute Python code with the environment's tools available as functions.

    This enables the CodeAct pattern where agents write Python code that
    calls tools directly, avoiding JSON-RPC overhead. The code runs in a
    fresh namespace seeded with the callables from ``get_callables()``.

    The same dict is used as both globals and locals. With two separate
    dicts ``exec`` runs the code as if it were a class body, so functions,
    lambdas and comprehensions written by the agent could not see names
    assigned at the top level of the same snippet.

    Args:
        code: Python code to execute. Set a variable named 'result' to
            capture the return value.

    Returns:
        Observation with the value of ``result`` (or None if the code never
        assigned it) in metadata["result"], or an error description in
        metadata["error"] if the code failed to compile or raised.
    """
    tools = self.get_callables()
    # Copy so the snippet can rebind names without touching the tool table.
    namespace: Dict[str, Any] = dict(tools)
    try:
        # SECURITY: exec() runs arbitrary code with this process's privileges.
        # Nothing here sandboxes it; callers must only pass code they are
        # prepared to run, or isolate the environment at the process level.
        exec(code, namespace)
        result = namespace.get("result")
        # A tool that merely happens to be named "result" is not a value
        # produced by the snippet.
        if "result" in tools and result is tools["result"]:
            result = None
        return Observation(done=False, reward=0.0, metadata={"result": result})
    except SyntaxError as e:
        return Observation(
            done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"}
        )
    except Exception as e:
        return Observation(done=False, reward=0.0, metadata={"error": str(e)})
| def _validate_tool_names(self, mcp_server: Any) -> None: | |
| """ | |
| Validate that no tools use reserved names. | |
| Reserved names (reset, step, state, close) are protected to maintain | |
| the dual API boundary between infrastructure and agent APIs. | |
| Args: | |
| mcp_server: The FastMCP server to validate. | |
| Raises: | |
| ValueError: If any tool uses a reserved name. | |
| """ | |
| tools_dict = self._get_server_tools(mcp_server) | |
| if tools_dict: | |
| tool_names = set(tools_dict.keys()) | |
| conflicts = tool_names & RESERVED_TOOL_NAMES | |
| if conflicts: | |
| raise ValueError( | |
| f"MCP tools cannot use reserved names: {sorted(conflicts)}. " | |
| f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}" | |
| ) | |
def tool(self, mode: Optional[str] = None) -> Callable:
    """
    Decorator factory for registering a tool, optionally bound to one mode.

    With ``mode=None`` the function is registered on the FastMCP server and
    is available in all modes. With a specific mode the function is NOT
    registered on the server; it is tracked by this environment together
    with an input schema derived from the function signature.

    Schema derivation:
    - ``int`` / ``float`` / ``bool`` / ``str`` / ``list`` / ``dict``
      annotations (as types or as their string names) map to the JSON Schema
      types integer / number / boolean / string / array / object; any other
      annotation falls back to "string".
    - ``*args`` and ``**kwargs`` are left out of the schema: tools are
      invoked with keyword arguments, so those parameters cannot be supplied
      by name.
    - Parameters without a default are listed as required.

    This method uses the registries created by ``__init__``, so it must be
    called after the base class has been initialized.

    Args:
        mode: Optional mode for the tool ("production" or "simulation").
            If None, the tool is available in all modes.

    Returns:
        A decorator that registers the function and returns it (or, for
        ``mode=None``, whatever the FastMCP ``tool()`` decorator returns).

    Raises:
        ValueError: If mode is not None, "production", or "simulation", or
            (when the decorator is applied) if the function's name is reserved.
    """
    if mode is not None and mode not in VALID_MODES:
        raise ValueError(
            f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None."
        )

    # (accepted annotations, JSON Schema type). String names are included
    # for annotations that are stored as strings rather than type objects.
    # A sequence of pairs is used rather than a dict so that
    # unhashable annotation objects cannot break the lookup.
    json_type_table = (
        ((bool, "bool"), "boolean"),
        ((int, "int"), "integer"),
        ((float, "float"), "number"),
        ((str, "str"), "string"),
        ((list, "list"), "array"),
        ((dict, "dict"), "object"),
    )
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    def decorator(func: Callable) -> Callable:
        tool_name = func.__name__

        # Validate tool name is not reserved
        if tool_name in RESERVED_TOOL_NAMES:
            raise ValueError(
                f"Tool name '{tool_name}' is reserved and cannot be used. "
                f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
            )

        # Mode-agnostic tools are registered with FastMCP as usual; the raw
        # function is also recorded under the None mode.
        if mode is None:
            decorated_func = self.mcp_server.tool()(func)
            self._mode_tools[tool_name][None] = func
            return decorated_func

        # Mode-specific tools are tracked here instead of in FastMCP.
        self._mode_tools[tool_name][mode] = func

        properties: Dict[str, Any] = {}
        required = []
        for param_name, param in inspect.signature(func).parameters.items():
            if param.kind in variadic_kinds:
                continue

            annotation = param.annotation
            json_type = next(
                (
                    schema_type
                    for accepted, schema_type in json_type_table
                    if annotation in accepted
                ),
                "string",  # default for unannotated or unrecognized types
            )
            properties[param_name] = {"type": json_type}

            # ``empty`` is a sentinel: compare by identity so a default
            # value with a custom __eq__ is never consulted.
            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        # Store the schema for this mode-specific tool
        self._mode_tool_schemas[tool_name][mode] = {
            "name": tool_name,
            "description": func.__doc__ or "",
            "input_schema": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }
        return func

    return decorator
def step(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """
    Execute one action, dispatching on its type.

    Routing:
    - ListToolsAction -> ``_handle_list_tools()``; yields a ListToolsObservation.
    - CallToolAction  -> ``_handle_call_tool()``; yields a CallToolObservation.
    - anything else   -> the subclass's ``_step_impl()``.

    Args:
        action: The action to execute.
        timeout_s: Optional timeout in seconds. It is forwarded to the tool
            call handler and to ``_step_impl()``; listing tools does not
            receive it.
        **kwargs: Extra arguments, forwarded to ``_step_impl()`` only.

    Returns:
        The observation produced by whichever handler took the action.
    """
    if isinstance(action, ListToolsAction):
        return self._handle_list_tools()

    if isinstance(action, CallToolAction):
        return self._handle_call_tool(action, timeout_s=timeout_s)

    # Not an MCP action: environment-specific handling.
    return self._step_impl(action, timeout_s=timeout_s, **kwargs)
def _handle_list_tools(self) -> ListToolsObservation:
    """
    Answer a ListToolsAction with the tools visible in the current mode.

    The result lists, in order:
    1. tools reported by the MCP client whose name has no entry in the
       mode-specific schema registry;
    2. one entry per registry tool that has a schema for mode None or,
       failing that, for the current mode.

    Returns:
        ListToolsObservation with the collected tools. If anything raises,
        an observation with an empty tool list and the error text in
        metadata["error"] (metadata["error_type"] == "list_tools_failed").
    """
    try:
        # Environments that never set ``_mode`` are treated as mode None.
        active_mode = getattr(self, "_mode", None)
        server_tools = run_async_safely(self._async_list_tools())

        listed = [
            Tool(
                name=entry.name,
                description=entry.description or "",
                input_schema=entry.inputSchema
                if hasattr(entry, "inputSchema")
                else {},
            )
            for entry in server_tools
            # Names present in the registry are described by the registry.
            if entry.name not in self._mode_tool_schemas
        ]

        for schemas_by_mode in self._mode_tool_schemas.values():
            if None in schemas_by_mode:
                chosen = schemas_by_mode[None]
            elif active_mode in schemas_by_mode:
                chosen = schemas_by_mode[active_mode]
            else:
                # Registered only for some other mode: hidden right now.
                continue
            listed.append(
                Tool(
                    name=chosen["name"],
                    description=chosen["description"],
                    input_schema=chosen["input_schema"],
                )
            )

        return ListToolsObservation(tools=listed)
    except Exception as e:
        # Report the failure through the observation instead of raising.
        return ListToolsObservation(
            tools=[],
            metadata={
                "error": str(e),
                "error_type": "list_tools_failed",
            },
        )
| async def _async_list_tools(self) -> list: | |
| """ | |
| Async helper to list tools from the MCP client. | |
| Returns: | |
| List of tool objects from the MCP server. | |
| """ | |
| async with self.mcp_client: | |
| return await self.mcp_client.list_tools() | |
def _handle_call_tool(
    self,
    action: CallToolAction,
    timeout_s: Optional[float] = None,
) -> CallToolObservation:
    """
    Handle a CallToolAction by invoking the specified tool.

    Tools registered through ``self.tool()`` are called directly: the
    mode-agnostic (None) implementation is preferred, then the one for the
    current mode; a tool registered only for other modes is reported as
    TOOL_NOT_FOUND. Any other tool name is forwarded to the FastMCP client.

    Error reporting for directly-called tools:
    - arguments that do not fit the function signature -> INVALID_ARGS
      (checked before the function runs);
    - an async tool exceeding the timeout -> TIMEOUT;
    - any exception raised by the tool -> EXECUTION_ERROR.
    Synchronous tools run in the calling thread and are not subject to the
    timeout.

    For FastMCP tools the error type is inferred from the exception message.

    Args:
        action: The CallToolAction containing tool_name and arguments. A
            falsy ``arguments`` value is treated as "no arguments".
        timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).

    Returns:
        CallToolObservation with the tool's result or an error.
    """
    timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
    tool_name = action.tool_name
    arguments = action.arguments or {}
    current_mode = getattr(self, "_mode", None)

    def failure(error_type: Any, message: str) -> CallToolObservation:
        """Build the error observation shared by every failure path."""
        return CallToolObservation(
            tool_name=tool_name,
            result=None,
            error=ToolError(error_type=error_type, message=message),
        )

    def timed_out() -> CallToolObservation:
        return failure(
            ToolErrorType.TIMEOUT,
            f"Tool '{tool_name}' timed out after {timeout} seconds",
        )

    if tool_name in self._mode_tools:
        mode_info = self._mode_tools[tool_name]
        if None in mode_info:
            # Mode-agnostic implementation wins.
            func = mode_info[None]
        elif current_mode in mode_info:
            func = mode_info[current_mode]
        else:
            return failure(
                ToolErrorType.TOOL_NOT_FOUND,
                f"Tool '{tool_name}' not available in {current_mode} mode",
            )

        # Check the arguments against the signature up front so that a bad
        # call is categorized as INVALID_ARGS, not as a tool crash.
        try:
            inspect.signature(func).bind(**arguments)
        except TypeError as e:
            return failure(ToolErrorType.INVALID_ARGS, str(e))
        except ValueError:
            # No introspectable signature; let the call itself decide.
            pass

        try:
            if inspect.iscoroutinefunction(func):
                try:
                    result = run_async_safely(
                        asyncio.wait_for(func(**arguments), timeout=timeout)
                    )
                except asyncio.TimeoutError:
                    return timed_out()
            else:
                result = func(**arguments)

            # Wrap result in CallToolResult format to match FastMCP behavior
            return CallToolObservation(
                tool_name=tool_name,
                result=CallToolResult(
                    content=[TextContent(type="text", text=str(result))],
                    structured_content={"result": result},
                    meta=None,
                    data=result,
                    is_error=False,
                ),
            )
        except Exception as e:
            return failure(ToolErrorType.EXECUTION_ERROR, str(e))

    # Not tracked by this environment: go through the FastMCP client.
    try:
        # run_async_safely handles both sync and async calling contexts.
        result = run_async_safely(
            asyncio.wait_for(
                self._async_call_tool(tool_name, arguments),
                timeout=timeout,
            )
        )
        return CallToolObservation(tool_name=tool_name, result=result)
    except asyncio.TimeoutError:
        return timed_out()
    except Exception as e:
        error_message = str(e)
        lowered = error_message.lower()
        # Heuristic categorization based on the exception text.
        if "not found" in lowered or "unknown tool" in lowered:
            error_type = ToolErrorType.TOOL_NOT_FOUND
        elif "invalid" in lowered or "argument" in lowered:
            error_type = ToolErrorType.INVALID_ARGS
        else:
            error_type = ToolErrorType.EXECUTION_ERROR
        return failure(error_type, error_message)
| async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: | |
| """ | |
| Async helper to call a tool on the MCP server. | |
| Args: | |
| tool_name: Name of the tool to invoke. | |
| arguments: Dictionary of arguments to pass to the tool. | |
| Returns: | |
| The result from the tool execution. | |
| """ | |
| async with self.mcp_client: | |
| return await self.mcp_client.call_tool(tool_name, arguments) | |
| def _step_impl( | |
| self, | |
| action: Action, | |
| timeout_s: Optional[float] = None, | |
| **kwargs: Any, | |
| ) -> Observation: | |
| """ | |
| Handle non-MCP actions in the environment. | |
| Subclasses must implement this method to handle any actions that are | |
| not ListToolsAction or CallToolAction. This is where environment-specific | |
| action processing should occur. | |
| Args: | |
| action: The action to execute (guaranteed not to be an MCP action). | |
| timeout_s: Optional timeout in seconds. | |
| **kwargs: Additional arguments. | |
| Returns: | |
| An Observation appropriate for the action. | |
| """ | |
| pass | |
def close(self) -> None:
    """
    Release the environment's MCP client and server.

    Subclasses that override this method should call super().close().
    """
    # The client is only ever entered via ``async with`` around individual
    # requests, so there is no open connection to shut down here; dropping
    # the references is all that is left to do.
    self.mcp_client = self.mcp_server = None