File size: 3,272 Bytes
816634a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Visual Memory Environment HTTP Client.

Connects to a running Visual Memory OpenEnv server over HTTP/WebSocket.
Agents interact via MCP tools exposed through step(CallToolAction(...)).
"""

from __future__ import annotations

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from openenv.core.env_server.mcp_types import (
    CallToolAction,
    ListToolsAction,
    Tool,
)

from .models import (
    VisualMemoryAction,
    VisualMemoryObservation,
    VisualMemoryState,
)


class VisualMemoryEnv(EnvClient[VisualMemoryAction, VisualMemoryObservation, VisualMemoryState]):
    """HTTP client for the Visual Memory Environment.

    Translates actions into the JSON payloads sent to the server and turns
    the server's JSON replies back into ``StepResult`` / ``VisualMemoryState``
    objects.

    Example:
        >>> async with VisualMemoryEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     result = await client.step(
        ...         CallToolAction(tool_name="load_scenario", arguments={"scenario_id": "hidden_grid_01"})
        ...     )
    """

    # Seconds to wait for the tool-listing HTTP request before giving up, so a
    # stalled server cannot block the caller forever. Subclasses may override.
    _LIST_TOOLS_TIMEOUT_S = 30.0

    @staticmethod
    def _derive_http_base(ws_url: str) -> str:
        """Return the HTTP(S) base URL matching a WebSocket endpoint URL.

        The leading ``ws://`` / ``wss://`` scheme is swapped for ``http://`` /
        ``https://`` and a single trailing ``/ws`` path segment is removed.

        ``str.rstrip("/ws")`` must not be used for the second step: it strips
        any trailing run of the characters ``/``, ``w`` and ``s``, so a URL
        such as ``ws://aws/ws`` would be mangled to ``http://a``.
        """
        if ws_url.startswith("wss://"):
            http_url = "https://" + ws_url[len("wss://"):]
        elif ws_url.startswith("ws://"):
            http_url = "http://" + ws_url[len("ws://"):]
        else:
            http_url = ws_url

        # Ignore trailing slashes so ".../ws/" is handled like ".../ws".
        http_url = http_url.rstrip("/")
        if http_url.endswith("/ws"):
            http_url = http_url[: -len("/ws")].rstrip("/")
        return http_url

    def list_tools(self, use_cache: bool = True) -> list[Tool]:
        """Return the tools advertised by the server.

        Sends a blocking HTTP ``POST {base}/step`` request carrying a
        ``list_tools`` action and builds ``Tool`` objects from the
        ``observation.tools`` entries of the JSON reply.

        Args:
            use_cache: When true and a non-empty tool list was fetched
                earlier, return that list without contacting the server.

        Returns:
            The list of tools; empty when the reply contains none.

        Raises:
            requests.exceptions.RequestException: If the request fails or
                exceeds ``_LIST_TOOLS_TIMEOUT_S`` seconds.
        """
        cached = getattr(self, "_tools_cache", None)
        if use_cache and cached:
            return cached

        import requests

        # self._ws_url is not defined in this class — presumably set by
        # EnvClient.__init__; verify against the base class.
        http_base = self._derive_http_base(self._ws_url)
        resp = requests.post(
            f"{http_base}/step",
            json={"action": {"type": "list_tools"}},
            timeout=self._LIST_TOOLS_TIMEOUT_S,
        )
        data = resp.json()
        raw_tools = data.get("observation", {}).get("tools", [])
        tools = [
            Tool(
                name=t["name"],
                description=t.get("description", ""),
                input_schema=t.get("input_schema", {}),
            )
            for t in raw_tools
        ]
        self._tools_cache = tools
        return tools

    def _step_payload(self, action: Any) -> Dict:
        """Convert an action object into the JSON-serializable step payload.

        ``ListToolsAction`` and ``CallToolAction`` map to explicit
        ``list_tools`` / ``call_tool`` payloads; any other object exposing
        ``model_dump`` is dumped as-is; anything else falls back to a payload
        holding only its ``tool_name`` (or ``""``).
        """
        if isinstance(action, ListToolsAction):
            return {"type": "list_tools"}
        if isinstance(action, CallToolAction):
            return {
                "type": "call_tool",
                "tool_name": action.tool_name,
                "arguments": action.arguments or {},
            }
        if hasattr(action, "model_dump"):
            return action.model_dump()
        # NOTE(review): this fallback always sends empty arguments, even if the
        # action carries some — confirm that is intended.
        return {"tool_name": getattr(action, "tool_name", ""), "arguments": {}}

    def _parse_result(self, payload: Dict) -> StepResult[VisualMemoryObservation]:
        """Build a ``StepResult`` from a server reply.

        Observation fields are read from ``payload["observation"]`` when that
        key is present, otherwise from ``payload`` itself. ``done`` and
        ``reward`` are always read from the top level of ``payload`` and are
        copied onto both the observation and the returned ``StepResult``.
        """
        obs_data = payload.get("observation", payload)
        done = payload.get("done", False)
        reward = payload.get("reward")
        observation = VisualMemoryObservation(
            tool_name=obs_data.get("tool_name", ""),
            result=obs_data.get("result"),
            error=obs_data.get("error"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> VisualMemoryState:
        """Build a ``VisualMemoryState`` from a server state reply.

        ``episode_id`` defaults to ``None`` and ``step_count`` to ``0`` when
        the corresponding key is missing from ``payload``.
        """
        return VisualMemoryState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )