""" Visual Memory Environment HTTP Client. Connects to a running Visual Memory OpenEnv server over HTTP/WebSocket. Agents interact via MCP tools exposed through step(CallToolAction(...)). """ from __future__ import annotations from typing import Any, Dict from openenv.core.client_types import StepResult from openenv.core.env_client import EnvClient from openenv.core.env_server.mcp_types import ( CallToolAction, ListToolsAction, Tool, ) from .models import ( VisualMemoryAction, VisualMemoryObservation, VisualMemoryState, ) class VisualMemoryEnv(EnvClient[VisualMemoryAction, VisualMemoryObservation, VisualMemoryState]): """HTTP client for the Visual Memory Environment. Example: >>> async with VisualMemoryEnv(base_url="http://localhost:8000") as client: ... result = await client.reset() ... result = await client.step( ... CallToolAction(tool_name="load_scenario", arguments={"scenario_id": "hidden_grid_01"}) ... ) """ def list_tools(self, use_cache: bool = True): if use_cache and hasattr(self, "_tools_cache") and self._tools_cache: return self._tools_cache import requests http_base = ( self._ws_url .replace("ws://", "http://") .replace("wss://", "https://") .rstrip("/ws") ) resp = requests.post( f"{http_base}/step", json={"action": {"type": "list_tools"}}, ) data = resp.json() raw_tools = data.get("observation", {}).get("tools", []) tools = [ Tool( name=t["name"], description=t.get("description", ""), input_schema=t.get("input_schema", {}), ) for t in raw_tools ] self._tools_cache = tools return tools def _step_payload(self, action: Any) -> Dict: if isinstance(action, ListToolsAction): return {"type": "list_tools"} if isinstance(action, CallToolAction): return { "type": "call_tool", "tool_name": action.tool_name, "arguments": action.arguments or {}, } if hasattr(action, "model_dump"): return action.model_dump() return {"tool_name": getattr(action, "tool_name", ""), "arguments": {}} def _parse_result(self, payload: Dict) -> StepResult[VisualMemoryObservation]: obs_data = payload.get("observation", payload) observation = VisualMemoryObservation( tool_name=obs_data.get("tool_name", ""), result=obs_data.get("result"), error=obs_data.get("error"), done=payload.get("done", False), reward=payload.get("reward"), metadata=obs_data.get("metadata", {}), ) return StepResult( observation=observation, reward=payload.get("reward"), done=payload.get("done", False), ) def _parse_state(self, payload: Dict[str, Any]) -> VisualMemoryState: return VisualMemoryState( episode_id=payload.get("episode_id"), step_count=payload.get("step_count", 0), )