Spaces:
Sleeping
Sleeping
| """ | |
| Visual Memory Environment HTTP Client. | |
| Connects to a running Visual Memory OpenEnv server over HTTP/WebSocket. | |
| Agents interact via MCP tools exposed through step(CallToolAction(...)). | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from openenv.core.env_server.mcp_types import ( | |
| CallToolAction, | |
| ListToolsAction, | |
| Tool, | |
| ) | |
| from .models import ( | |
| VisualMemoryAction, | |
| VisualMemoryObservation, | |
| VisualMemoryState, | |
| ) | |
class VisualMemoryEnv(EnvClient[VisualMemoryAction, VisualMemoryObservation, VisualMemoryState]):
    """HTTP client for the Visual Memory Environment.

    Example:
        >>> async with VisualMemoryEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     result = await client.step(
        ...         CallToolAction(tool_name="load_scenario", arguments={"scenario_id": "hidden_grid_01"})
        ...     )
    """

    # Seconds to wait for the tool-listing HTTP request before giving up, so a
    # stalled server cannot block the caller indefinitely.
    _LIST_TOOLS_TIMEOUT = 30.0

    @staticmethod
    def _http_base_from_ws(ws_url: str) -> str:
        """Derive the HTTP base URL from a WebSocket URL.

        Maps ``ws://`` to ``http://`` and ``wss://`` to ``https://``, then
        drops a trailing ``/ws`` path segment and any trailing slashes.

        Args:
            ws_url: WebSocket URL, e.g. ``ws://localhost:8000/ws``.

        Returns:
            The matching HTTP base URL, e.g. ``http://localhost:8000``.
        """
        if ws_url.startswith("wss://"):
            base = "https://" + ws_url[len("wss://"):]
        elif ws_url.startswith("ws://"):
            base = "http://" + ws_url[len("ws://"):]
        else:
            base = ws_url

        base = base.rstrip("/")
        # Remove the literal "/ws" suffix only. str.rstrip("/ws") would strip
        # any trailing run of '/', 'w' and 's' characters and corrupt hosts
        # such as "my-news" or ports/paths ending in those letters.
        if base.endswith("/ws"):
            base = base[: -len("/ws")]
        return base.rstrip("/")

    def list_tools(self, use_cache: bool = True) -> list[Tool]:
        """Fetch the tools exposed by the server via a ``list_tools`` step.

        Args:
            use_cache: When true and a non-empty tool list was stored by an
                earlier call, return it without contacting the server.

        Returns:
            A list of ``Tool`` objects built from the ``observation.tools``
            entries of the response (empty if the response has none).

        Raises:
            requests.RequestException: If the request fails, times out, or
                the server answers with an HTTP error status.
        """
        cached = getattr(self, "_tools_cache", None)
        if use_cache and cached:
            return cached

        # Imported lazily, as in the original, so the module can be imported
        # without requests installed.
        import requests

        # NOTE(review): ``_ws_url`` is not defined in this class; presumably
        # the EnvClient base class sets it — confirm.
        http_base = self._http_base_from_ws(self._ws_url)
        resp = requests.post(
            f"{http_base}/step",
            json={"action": {"type": "list_tools"}},
            timeout=self._LIST_TOOLS_TIMEOUT,
        )
        # Surface HTTP errors instead of silently returning an empty tool list.
        resp.raise_for_status()
        data = resp.json()

        # Tolerate a missing or null "observation" / "tools" entry.
        observation = data.get("observation") or {}
        raw_tools = observation.get("tools") or []
        tools = [
            Tool(
                name=t["name"],
                description=t.get("description", ""),
                input_schema=t.get("input_schema", {}),
            )
            for t in raw_tools
        ]
        self._tools_cache = tools
        return tools

    def _step_payload(self, action: Any) -> Dict:
        """Serialize an action into the dict payload sent for a step.

        ``ListToolsAction`` and ``CallToolAction`` get explicit ``type`` tags;
        any other object with ``model_dump`` is dumped as-is; anything else
        falls back to its ``tool_name`` attribute (or ``""``) with empty
        arguments.
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}
        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments or {},
            }
        if hasattr(action, "model_dump"):
            return action.model_dump()
        return {"tool_name": getattr(action, "tool_name", ""), "arguments": {}}

    def _parse_result(self, payload: Dict) -> StepResult[VisualMemoryObservation]:
        """Build a ``StepResult`` from a step response payload.

        Observation fields are read from ``payload["observation"]`` when that
        key exists, otherwise from the payload itself; ``done`` and ``reward``
        are always read from the top-level payload.
        """
        obs_data = payload.get("observation", payload)
        done = payload.get("done", False)
        reward = payload.get("reward")
        observation = VisualMemoryObservation(
            tool_name=obs_data.get("tool_name", ""),
            result=obs_data.get("result"),
            error=obs_data.get("error"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> VisualMemoryState:
        """Build a ``VisualMemoryState`` from a state response payload.

        ``episode_id`` defaults to ``None`` and ``step_count`` to ``0`` when
        absent from the payload.
        """
        return VisualMemoryState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )