# wildfire_env/src/openenv/core/mcp_client.py
# (Hugging Face page residue) shankerram3's picture
# Upload folder using huggingface_hub
# b750f9a verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
MCP Client classes for tool-calling environments.
This module provides client classes for interacting with MCP-enabled environments:
- MCPClientBase: Base class with shared tool discovery
- MCPToolClient: Client for tool-calling style (one tool per step)
These clients abstract away the MCP protocol details, providing a clean interface
for listing and calling tools on remote environments.
Example:
>>> from openenv.core.mcp_client import MCPToolClient
>>>
>>> with MCPToolClient(base_url="http://localhost:8000") as env:
... # Discover available tools
... tools = env.list_tools()
... print([t.name for t in tools])
...
... # Call a tool
... result = env.call_tool("echo_message", message="Hello!")
... print(result)
"""
from typing import Any, Dict, List, Optional
from .client_types import StepResult, StateT
from .env_client import EnvClient
from .env_server.mcp_types import (
CallToolAction,
CallToolObservation,
ListToolsAction,
ListToolsObservation,
Tool,
ToolError,
)
from .env_server.types import Observation, State
class MCPClientBase(EnvClient[Any, Observation, State]):
    """
    Base class for MCP clients with tool discovery.

    This class provides the common `list_tools()` method for discovering
    available tools from an MCP-enabled environment, together with the
    conversion helpers that translate between action/observation objects and
    the JSON payloads exchanged with the environment server. Subclasses
    implement specific interaction patterns (tool-calling or CodeAct).

    Attributes:
        _tools_cache: Cached list of tools (populated on first successful
            `list_tools()` call; `None` until then).
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        provider: Optional[Any] = None,
    ):
        """
        Initialize MCP client.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
            connect_timeout_s: Timeout for establishing WebSocket connection.
            message_timeout_s: Timeout for receiving responses to messages.
            provider: Optional container/runtime provider for lifecycle management.
        """
        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            provider=provider,
        )
        self._tools_cache: Optional[List[Tool]] = None

    def list_tools(self, use_cache: bool = True) -> List[Tool]:
        """
        Discover available tools from the environment.

        Args:
            use_cache: If True, return cached tools if available.
                Set to False to force a fresh request.

        Returns:
            List of Tool objects with name, description, and input_schema.

        Raises:
            RuntimeError: If the observation returned for the ListToolsAction
                carries no `tools` attribute (i.e. the server did not answer
                with a tool listing).

        Example:
            >>> tools = env.list_tools()
            >>> for tool in tools:
            ...     print(f"{tool.name}: {tool.description}")
        """
        if use_cache and self._tools_cache is not None:
            return self._tools_cache

        observation = self.step(ListToolsAction()).observation
        # _parse_result falls back to a plain Observation when the response
        # has no "tools" key; surface that clearly instead of failing with an
        # AttributeError, and leave the cache untouched so a retry is possible.
        tools = getattr(observation, "tools", None)
        if tools is None:
            raise RuntimeError(
                "Environment did not return a tool listing "
                f"(got {type(observation).__name__})"
            )

        self._tools_cache = tools
        return self._tools_cache

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """
        Convert an Action object to the JSON data expected by the env server.

        Args:
            action: A ListToolsAction, a CallToolAction, or any other object.

        Returns:
            A dict with a "type" discriminator for the two MCP actions. For
            any other object, the result of its `model_dump()` if it has one,
            otherwise `{"action": str(action)}`.
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}

        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments,
            }

        # For unknown actions, try to serialize as dict
        if hasattr(action, "model_dump"):
            return action.model_dump()
        return {"action": str(action)}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
        """
        Convert a JSON response from the env server to StepResult[Observation].

        The observation type is chosen from the keys present under
        "observation": a "tools" key yields a ListToolsObservation, a
        "tool_name" key yields a CallToolObservation, and anything else yields
        a generic Observation.

        Args:
            payload: Decoded response dict; "observation", "done" and "reward"
                are read from it.

        Returns:
            StepResult wrapping the parsed observation, with `done` defaulting
            to False and `reward` defaulting to None when absent.
        """
        # A missing key and an explicit JSON null are both treated as "no data".
        obs_data = payload.get("observation") or {}
        done = payload.get("done", False)
        reward = payload.get("reward")
        metadata = obs_data.get("metadata") or {}

        observation: Observation
        if "tools" in obs_data:
            # ListToolsObservation: accept both snake_case and camelCase
            # spellings of the schema key.
            tools = [
                Tool(
                    name=entry.get("name", ""),
                    description=entry.get("description", ""),
                    input_schema=entry.get(
                        "input_schema", entry.get("inputSchema", {})
                    ),
                )
                for entry in (obs_data.get("tools") or [])
            ]
            observation = ListToolsObservation(
                tools=tools,
                done=done,
                reward=reward,
                metadata=metadata,
            )
        elif "tool_name" in obs_data:
            # CallToolObservation: a falsy/absent "error" means success.
            error_data = obs_data.get("error")
            error = ToolError(**error_data) if error_data else None
            observation = CallToolObservation(
                tool_name=obs_data.get("tool_name", ""),
                result=obs_data.get("result"),
                error=error,
                done=done,
                reward=reward,
                metadata=metadata,
            )
        else:
            # Generic observation
            observation = Observation(
                done=done,
                reward=reward,
                metadata=metadata,
            )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """
        Convert a JSON response from the state endpoint to a State object.

        Args:
            payload: Decoded response dict; "episode_id" (default None) and
                "step_count" (default 0) are read from it.

        Returns:
            State populated with the episode id and step count.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
class MCPToolClient(MCPClientBase):
    """
    Client for tool-calling style MCP interactions.

    Each step invokes a single tool. Use this for traditional function-calling
    agent patterns where the agent decides which tool to call next.

    This client provides convenience methods for tool discovery and invocation:

    - `list_tools()`: Get all available tools with their schemas
    - `call_tool(name, **kwargs)`: Invoke a tool by name with arguments
    - `get_tool(name)` / `has_tool(name)`: Look up a single tool by name

    Example:
        >>> with MCPToolClient(base_url="http://localhost:8000") as env:
        ...     # Reset the environment
        ...     env.reset()
        ...
        ...     # Discover available tools
        ...     tools = env.list_tools()
        ...     print([t.name for t in tools])  # ['echo_message', 'echo_with_length']
        ...
        ...     # Call a tool directly
        ...     result = env.call_tool("echo_message", message="Hello!")
        ...     print(result)  # "Hello!"
        ...
        ...     # Or use the full action interface
        ...     from openenv.core.env_server.mcp_types import CallToolAction
        ...     step_result = env.step(CallToolAction(
        ...         tool_name="echo_with_length",
        ...         arguments={"message": "Test"}
        ...     ))
        ...     print(step_result.observation.result)
    """

    def call_tool(self, name: str, **kwargs: Any) -> Any:
        """
        Call a tool by name.

        This is a convenience method that creates a CallToolAction, executes it,
        and returns the result directly. For more control, use `step()` with
        a CallToolAction directly.

        Note:
            Because this method's first parameter is itself called `name`, a
            tool argument that is also called `name` cannot be passed through
            `**kwargs` (Python rejects the call with "multiple values for
            argument 'name'"). For such tools, use
            `step(CallToolAction(tool_name=..., arguments={"name": ...}))`.

        Args:
            name: Name of the tool to invoke (must match a tool from `list_tools()`).
            **kwargs: Arguments to pass to the tool. Must match the tool's input_schema.

        Returns:
            The tool's result. The type depends on the tool being called.
            If the result has a `.data` attribute or is a dict with a "data"
            key, that inner value is returned. If the observation is not a
            CallToolObservation, the observation itself is returned.

        Raises:
            RuntimeError: If the server returns an error response.

        Example:
            >>> result = env.call_tool("add", a=5, b=3)
            >>> print(result)  # 8
            >>>
            >>> result = env.call_tool("echo_message", message="Hello, Claude!")
            >>> print(result)  # "Hello, Claude!"
        """
        action = CallToolAction(tool_name=name, arguments=kwargs)
        observation = self.step(action).observation

        # Fallback for unexpected observation types
        if not isinstance(observation, CallToolObservation):
            return observation

        # Check for transport/framework errors
        error = observation.error
        if error is not None:
            # error_type is normally an enum; fall back to the raw value so a
            # plain-string type does not mask the real failure.
            error_type = getattr(error.error_type, "value", error.error_type)
            raise RuntimeError(
                f"Tool '{name}' failed: {error.message} "
                f"(type: {error_type})"
            )

        tool_result = observation.result
        # Handle FastMCP CallToolResult objects
        # - As object: has .data attribute
        # - As dict (from JSON): has "data" key
        if hasattr(tool_result, "data"):
            return tool_result.data
        if isinstance(tool_result, dict) and "data" in tool_result:
            return tool_result["data"]
        return tool_result

    def get_tool(self, name: str) -> Optional[Tool]:
        """
        Get a specific tool by name.

        Uses `list_tools()` with its default caching, so repeated lookups do
        not trigger a new request once the tool list has been fetched.

        Args:
            name: Name of the tool to find.

        Returns:
            The first Tool whose name matches, or None if there is no match.

        Example:
            >>> tool = env.get_tool("echo_message")
            >>> if tool:
            ...     print(tool.description)
            ...     print(tool.input_schema)
        """
        return next(
            (tool for tool in self.list_tools() if tool.name == name),
            None,
        )

    def has_tool(self, name: str) -> bool:
        """
        Check if a tool exists.

        Args:
            name: Name of the tool to check.

        Returns:
            True if the tool exists, False otherwise.
        """
        return self.get_tool(name) is not None