# Source: repl_env-v2-1-0/src/core/mcp_client.py
# Uploaded by burtenshaw (HF Staff) using huggingface_hub; commit eba9bbb (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
MCP Client classes for tool-calling environments.
This module provides async client classes for interacting with MCP-enabled environments:
- MCPClientBase: Base class with shared tool discovery
- MCPToolClient: Client for tool-calling style (one tool per step)
These clients abstract away the MCP protocol details, providing a clean interface
for listing and calling tools on remote environments. All clients are async by default.
Architecture Overview::
    ┌─────────────────────────────────────────────────────────┐
    │                    HTTPEnvServer                        │
    ├─────────────────────────────────────────────────────────┤
    │  Simulation Mode (default):                             │
    │    /ws     → OpenEnv protocol (reset/step/state)        │
    │    /mcp    → MCP JSON-RPC (tools/list, tools/call)      │
    │    /reset, /step, /state → HTTP endpoints               │
    ├─────────────────────────────────────────────────────────┤
    │  Production Mode (use_production_mode=True):            │
    │    /mcp    → MCP JSON-RPC (tools/list, tools/call)      │
    │    Bypasses step() for direct tool access               │
    └─────────────────────────────────────────────────────────┘
Client Usage:
    MCPToolClient (default)    → /ws  (step-based, with rewards)
    MCPToolClient (production) → /mcp (direct tool access, no rewards)
Example (async):
>>> from openenv.core.mcp_client import MCPToolClient
>>>
>>> async with MCPToolClient(base_url="http://localhost:8000") as env:
... # Discover available tools
... tools = await env.list_tools()
... print([t.name for t in tools])
...
... # Call a tool
... result = await env.call_tool("echo_message", message="Hello!")
... print(result)
Example (sync wrapper):
>>> env = MCPToolClient(base_url="http://localhost:8000").sync()
>>> with env:
... tools = env.list_tools()
... result = env.call_tool("echo_message", message="Hello!")
"""
import logging
from typing import Any, Dict, List, Optional

from .client_types import StepResult
from .env_client import EnvClient
from .env_server.mcp_types import (
    CallToolAction,
    CallToolObservation,
    ListToolsAction,
    ListToolsObservation,
    Tool,
    ToolError,
)
from .env_server.types import Observation, State
class MCPClientBase(EnvClient[Any, Observation, State]):
    """
    Base class for MCP clients with tool discovery.

    This class provides the common `list_tools()` method for discovering
    available tools from an MCP-enabled environment. Subclasses implement
    specific interaction patterns (tool-calling or CodeAct).

    Attributes:
        _tools_cache: Cached list of tools (populated on first successful
            `list_tools()` call).
        use_production_mode: When True, `list_tools()` queries the `/mcp`
            JSON-RPC endpoint over HTTP instead of sending a
            `ListToolsAction` through `step()`. Initialised to False.
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        provider: Optional[Any] = None,
        mode: Optional[str] = None,
    ):
        """
        Initialize MCP client.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
            connect_timeout_s: Timeout for establishing WebSocket connection.
            message_timeout_s: Timeout for receiving responses to messages.
                Also used as the timeout of the HTTP `tools/list` request
                issued by `list_tools()` when `use_production_mode` is True.
            provider: Optional container/runtime provider for lifecycle management.
            mode: Communication mode. Must be 'production' (case-insensitive)
                for MCP clients. Defaults to 'production'.

        Raises:
            ValueError: If `mode` is anything other than 'production'.
        """
        # MCP clients default to production mode; the argument only exists so
        # that an explicit, wrong value can be rejected with a clear message.
        if mode is None:
            mode = "production"

        mode_lower = mode.lower()
        if mode_lower != "production":
            # Name the concrete class so subclasses report themselves correctly.
            raise ValueError(
                f"{type(self).__name__} only supports 'production' mode, got '{mode}'. "
                f"Use GenericEnvClient for simulation mode."
            )

        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            provider=provider,
            # Forward the normalised value so "Production" and "production"
            # are treated identically downstream.
            mode=mode_lower,
        )
        self._tools_cache: Optional[List[Tool]] = None
        self.use_production_mode = False
        # Bound the blocking HTTP call in list_tools(); without a timeout
        # `requests` may wait forever on an unresponsive server.
        self._mcp_http_timeout_s = message_timeout_s

    @staticmethod
    def _parse_tools(raw_tools: Optional[List[Dict[str, Any]]]) -> List[Tool]:
        """Build `Tool` objects from JSON dicts (accepts snake_case or camelCase schema key)."""
        return [
            Tool(
                name=t.get("name", ""),
                description=t.get("description", ""),
                input_schema=t.get("input_schema", t.get("inputSchema", {})),
            )
            for t in (raw_tools or [])
        ]

    def _mcp_http_url(self) -> str:
        """Derive the HTTP(S) `/mcp` endpoint URL from the client's WebSocket URL."""
        url = self._ws_url
        # Only rewrite the scheme prefix, never matches elsewhere in the URL.
        if url.startswith("wss://"):
            url = "https://" + url[len("wss://"):]
        elif url.startswith("ws://"):
            url = "http://" + url[len("ws://"):]

        # Drop a trailing "/ws" path segment. str.rstrip("/ws") must not be
        # used here: it strips any trailing '/', 'w' or 's' characters and
        # would corrupt hosts/paths such as "http://hosts/ws".
        url = url.rstrip("/")
        if url.endswith("/ws"):
            url = url[: -len("/ws")]
        return url.rstrip("/") + "/mcp"

    def _list_tools_over_http(self) -> Optional[List[Tool]]:
        """
        Query `tools/list` on the `/mcp` JSON-RPC endpoint.

        Returns:
            The parsed tools, or None if the request failed or the response
            did not contain `result.tools`. Failures are logged, not raised.
        """
        import requests

        log = logging.getLogger(__name__)
        url = self._mcp_http_url()
        try:
            # NOTE(review): this is a blocking call made from an async method;
            # the timeout bounds how long it can block the caller.
            response = requests.post(
                url,
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/list",
                    "params": {},
                    "id": 1,
                },
                timeout=self._mcp_http_timeout_s,
            )
            data = response.json()
            if "result" in data and "tools" in data["result"]:
                return self._parse_tools(data["result"]["tools"])
            log.warning("MCP tools/list response from %s has no result.tools", url)
        except Exception as exc:
            # Deliberately best-effort: discovery failure yields an empty
            # list for the caller, but the cause is recorded for debugging.
            log.warning("MCP tools/list request to %s failed: %s", url, exc)
        return None

    async def list_tools(self, use_cache: bool = True) -> List[Tool]:
        """
        Discover available tools from the environment.

        Args:
            use_cache: If True, return cached tools if available.
                Set to False to force a fresh request.

        Returns:
            List of Tool objects with name, description, and input_schema.
            When `use_production_mode` is True and the HTTP request fails or
            returns no tool list, an empty list is returned and nothing is
            cached.

        Example:
            >>> tools = await env.list_tools()
            >>> for tool in tools:
            ...     print(f"{tool.name}: {tool.description}")
        """
        if use_cache and self._tools_cache is not None:
            return self._tools_cache

        # Use production mode HTTP endpoint if enabled
        if self.use_production_mode:
            tools = self._list_tools_over_http()
            if tools is None:
                return []
            self._tools_cache = tools
            return tools

        result = await self.step(ListToolsAction())
        self._tools_cache = result.observation.tools
        return self._tools_cache

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """
        Convert an Action object to the JSON data expected by the env server.

        `ListToolsAction` and `CallToolAction` get explicit payloads; any
        other object is serialised with `model_dump()` when available, else
        wrapped as ``{"action": str(action)}``.
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}
        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments,
            }
        # For unknown actions, try to serialize as dict
        if hasattr(action, "model_dump"):
            return action.model_dump()
        return {"action": str(action)}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
        """
        Convert a JSON response from the env server to StepResult[Observation].

        The observation type is chosen from the keys present in
        ``payload["observation"]``: ``"tools"`` -> `ListToolsObservation`,
        ``"tool_name"`` -> `CallToolObservation`, otherwise a generic
        `Observation`. A missing or null observation is treated as empty.
        """
        # `or {}` also covers an explicit JSON null, which .get's default does not.
        obs_data = payload.get("observation") or {}
        done = payload.get("done", False)
        reward = payload.get("reward")
        metadata = obs_data.get("metadata") or {}

        observation: Observation
        if "tools" in obs_data:
            observation = ListToolsObservation(
                tools=self._parse_tools(obs_data.get("tools")),
                done=done,
                reward=reward,
                metadata=metadata,
            )
        elif "tool_name" in obs_data:
            error = None
            if obs_data.get("error"):
                error = ToolError(**obs_data["error"])
            observation = CallToolObservation(
                tool_name=obs_data.get("tool_name", ""),
                result=obs_data.get("result"),
                error=error,
                done=done,
                reward=reward,
                metadata=metadata,
            )
        else:
            # Generic observation
            observation = Observation(
                done=done,
                reward=reward,
                metadata=metadata,
            )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """Convert a JSON response from the state endpoint to a State object."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
class MCPToolClient(MCPClientBase):
    """
    Async client for tool-calling style MCP interactions.

    Each step invokes a single tool. Use this for traditional function-calling
    agent patterns where the agent decides which tool to call next.

    This client provides convenience methods for tool discovery and invocation:
    - `list_tools()`: Get all available tools with their schemas
    - `call_tool(name, **kwargs)`: Invoke a tool by name with arguments

    Example (async):
        >>> async with MCPToolClient(base_url="http://localhost:8000") as env:
        ...     # Reset the environment
        ...     await env.reset()
        ...
        ...     # Discover available tools
        ...     tools = await env.list_tools()
        ...     print([t.name for t in tools])  # ['echo_message', 'echo_with_length']
        ...
        ...     # Call a tool directly
        ...     result = await env.call_tool("echo_message", message="Hello!")
        ...     print(result)  # "Hello!"
        ...
        ...     # Or use the full action interface
        ...     from openenv.core.env_server.mcp_types import CallToolAction
        ...     step_result = await env.step(CallToolAction(
        ...         tool_name="echo_with_length",
        ...         arguments={"message": "Test"}
        ...     ))
        ...     print(step_result.observation.result)

    Example (sync wrapper):
        >>> env = MCPToolClient(base_url="http://localhost:8000").sync()
        >>> with env:
        ...     tools = env.list_tools()
        ...     result = env.call_tool("echo_message", message="Hello!")
    """

    async def call_tool(self, name: str, **kwargs: Any) -> Any:
        """
        Call a tool by name.

        This is a convenience method that creates a CallToolAction, executes it,
        and returns the result directly. For more control, use `step()` with
        a CallToolAction directly.

        Args:
            name: Name of the tool to invoke (must match a tool from `list_tools()`).
            **kwargs: Arguments to pass to the tool. Must match the tool's input_schema.
                A tool argument that is itself called ``name`` cannot be passed
                this way (it collides with this method's `name` parameter and
                raises TypeError); use
                ``step(CallToolAction(tool_name=..., arguments={"name": ...}))``
                for such tools.

        Returns:
            The tool's result. The type depends on the tool being called.
            If the result carries a ``data`` attribute or ``"data"`` key, that
            value is returned instead of the wrapper. If the observation is
            not a `CallToolObservation`, the observation itself is returned.

        Raises:
            RuntimeError: If the server returns an error response.

        Example:
            >>> result = await env.call_tool("add", a=5, b=3)
            >>> print(result)  # 8
            >>>
            >>> result = await env.call_tool("greet", person="Claude")
            >>> print(result)  # "Hello, Claude!"
        """
        action = CallToolAction(tool_name=name, arguments=kwargs)
        step_result = await self.step(action)
        obs = step_result.observation

        # Fallback for unexpected observation types
        if not isinstance(obs, CallToolObservation):
            return obs

        # Check for transport/framework errors
        if obs.error is not None:
            error_type = obs.error.error_type
            # error_type is normally an enum; fall back to the raw value so a
            # plain string does not mask the real failure with AttributeError.
            raise RuntimeError(
                f"Tool '{name}' failed: {obs.error.message} "
                f"(type: {getattr(error_type, 'value', error_type)})"
            )

        tool_result = obs.result
        # Handle FastMCP CallToolResult objects
        # - As object: has .data attribute
        # - As dict (from JSON): has "data" key
        if hasattr(tool_result, "data"):
            return tool_result.data
        if isinstance(tool_result, dict) and "data" in tool_result:
            return tool_result["data"]
        return tool_result

    async def get_tool(self, name: str) -> Optional[Tool]:
        """
        Get a specific tool by name.

        Args:
            name: Name of the tool to find.

        Returns:
            The first Tool from `list_tools()` whose name matches, or None.

        Example:
            >>> tool = await env.get_tool("echo_message")
            >>> if tool:
            ...     print(tool.description)
            ...     print(tool.input_schema)
        """
        tools = await self.list_tools()
        return next((tool for tool in tools if tool.name == name), None)

    async def has_tool(self, name: str) -> bool:
        """
        Check if a tool exists.

        Args:
            name: Name of the tool to check.

        Returns:
            True if the tool exists, False otherwise.
        """
        return await self.get_tool(name) is not None