# visual_memory / client.py
# kdemon1011's picture
# Upload folder using huggingface_hub
# 816634a verified
"""
Visual Memory Environment HTTP Client.
Connects to a running Visual Memory OpenEnv server over HTTP/WebSocket.
Agents interact via MCP tools exposed through step(CallToolAction(...)).
"""
from __future__ import annotations
from typing import Any, Dict
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from openenv.core.env_server.mcp_types import (
CallToolAction,
ListToolsAction,
Tool,
)
from .models import (
VisualMemoryAction,
VisualMemoryObservation,
VisualMemoryState,
)
class VisualMemoryEnv(EnvClient[VisualMemoryAction, VisualMemoryObservation, VisualMemoryState]):
    """HTTP client for the Visual Memory Environment.

    Example:
        >>> async with VisualMemoryEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     result = await client.step(
        ...         CallToolAction(tool_name="load_scenario", arguments={"scenario_id": "hidden_grid_01"})
        ...     )
    """

    # Seconds to wait for the server when listing tools over plain HTTP.
    # Without a timeout ``requests`` would block forever on a stalled server.
    _LIST_TOOLS_TIMEOUT_S = 30.0

    @staticmethod
    def _http_base_from_ws_url(ws_url: str) -> str:
        """Derive the HTTP(S) base URL from a WebSocket URL.

        ``ws://`` / ``wss://`` schemes become ``http://`` / ``https://`` and a
        single trailing ``/ws`` path segment (plus trailing slashes) is
        removed. Any other URL is returned with only trailing slashes removed.
        """
        if ws_url.startswith("wss://"):
            base = "https://" + ws_url[len("wss://"):]
        elif ws_url.startswith("ws://"):
            base = "http://" + ws_url[len("ws://"):]
        else:
            base = ws_url

        base = base.rstrip("/")
        # NOTE: str.rstrip("/ws") would strip any trailing run of the
        # characters '/', 'w' and 's' (e.g. ".../envs/ws" -> ".../env"),
        # so the literal suffix is removed explicitly instead.
        if base.endswith("/ws"):
            base = base[: -len("/ws")]
        return base.rstrip("/")

    def list_tools(self, use_cache: bool = True) -> list[Tool]:
        """Return the tools advertised by the server.

        Sends a ``list_tools`` action to the HTTP ``/step`` endpoint and
        converts each entry of ``observation.tools`` into a ``Tool``.

        Args:
            use_cache: When true and a non-empty result from a previous call
                is stored on the instance, return it without a request.

        Returns:
            The list of ``Tool`` objects; also stored in ``self._tools_cache``.

        Raises:
            requests.RequestException: On connection failure, timeout, or a
                non-2xx HTTP response.
        """
        cached = getattr(self, "_tools_cache", None)
        if use_cache and cached:
            return cached

        # Imported lazily, as in the original, so importing this module does
        # not require ``requests``.
        import requests

        # NOTE(review): ``self._ws_url`` is not defined in this class; it is
        # presumably set by the EnvClient base class - confirm.
        http_base = self._http_base_from_ws_url(self._ws_url)
        resp = requests.post(
            f"{http_base}/step",
            json={"action": {"type": "list_tools"}},
            timeout=self._LIST_TOOLS_TIMEOUT_S,
        )
        # Surface HTTP errors instead of silently reporting "no tools".
        resp.raise_for_status()
        data = resp.json()

        # Tolerate a missing or null "observation" / "tools" field.
        raw_tools = (data.get("observation") or {}).get("tools") or []
        tools = [
            Tool(
                name=t["name"],
                description=t.get("description", ""),
                input_schema=t.get("input_schema", {}),
            )
            for t in raw_tools
        ]
        self._tools_cache = tools
        return tools

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """Serialize an action into the dict payload sent for a step.

        ``ListToolsAction`` and ``CallToolAction`` get explicit ``type`` tags;
        any other object exposing ``model_dump()`` is dumped as-is; anything
        else is reduced to its ``tool_name`` / ``arguments`` attributes.
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}
        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments or {},
            }
        if hasattr(action, "model_dump"):
            return action.model_dump()
        # Duck-typed fallback: forward the action's arguments when present
        # rather than always sending an empty dict.
        return {
            "tool_name": getattr(action, "tool_name", ""),
            "arguments": getattr(action, "arguments", None) or {},
        }

    def _parse_result(self, payload: Dict) -> StepResult[VisualMemoryObservation]:
        """Build a ``StepResult`` from a server response payload.

        Observation fields are read from ``payload["observation"]`` when that
        key exists, otherwise from ``payload`` itself. ``done`` and ``reward``
        are always read from the top-level payload and are copied onto both
        the observation and the returned ``StepResult``.
        """
        obs_data = payload.get("observation", payload)
        done = payload.get("done", False)
        reward = payload.get("reward")
        observation = VisualMemoryObservation(
            tool_name=obs_data.get("tool_name", ""),
            result=obs_data.get("result"),
            error=obs_data.get("error"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> VisualMemoryState:
        """Build a ``VisualMemoryState`` from a state payload.

        ``episode_id`` defaults to ``None`` and ``step_count`` to ``0`` when
        absent from the payload.
        """
        return VisualMemoryState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )