File size: 2,565 Bytes
816634a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Data models for the Visual Memory Environment.

``VisualMemoryAction`` declares explicit Pydantic fields (``tool_name`` and
``arguments_json``) so the OpenEnv web interface can render interactive form
inputs for it when deployed on Hugging Face Spaces.
"""

from __future__ import annotations

import json as _json
from typing import Any, Union

from pydantic import ConfigDict, Field, TypeAdapter

from openenv.core.env_server.mcp_types import (
    CallToolAction,
    CallToolObservation,
    ListToolsAction,
    ListToolsObservation,
)
from openenv.core.env_server.types import Action, Observation, State

# Validates a raw MCP action payload (a plain dict) into either a
# ListToolsAction or a CallToolAction. VisualMemoryAction.model_validate
# delegates to this for dicts whose "type" is "call_tool" or "list_tools".
_mcp_action_adapter = TypeAdapter(Union[ListToolsAction, CallToolAction])

# Human-readable, comma-separated list of the MCP tool names; it is
# interpolated into the ``tool_name`` field description shown in the web UI.
_AVAILABLE_TOOLS = ", ".join(
    (
        "list_tools",
        "get_session_info",
        "list_scenarios",
        "load_scenario",
        "reset_scenario",
        "get_board_view",
        "get_status",
        "reveal_cell",
        "inspect_region",
        "flag_cell",
        "unflag_cell",
        "move_viewport",
        "submit_solution",
        "recall_log",
        "get_action_history",
        "get_progress_stats",
        "auto_solve",
        "peek_hidden_cell",
        "undo_last_action",
    )
)


class VisualMemoryAction(Action):
    """Action with explicit fields for the web UI and MCP compatibility.

    The web form submits a ``tool_name`` plus a JSON-encoded
    ``arguments_json`` string; :meth:`to_mcp_action` turns that pair into an
    MCP ``ListToolsAction`` or ``CallToolAction``. :meth:`model_validate`
    additionally accepts raw MCP action payloads (dicts tagged with
    ``type="call_tool"`` or ``type="list_tools"``) and returns the matching
    MCP action directly.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    tool_name: str = Field(
        default="list_tools",
        description=f"MCP tool to invoke. Available: {_AVAILABLE_TOOLS}",
    )
    arguments_json: str = Field(
        default="{}",
        description=(
            'Tool arguments as a JSON string. Examples: '
            '"{}" for no args, '
            '\'{"scenario_id":"hidden_grid_01"}\' for load_scenario, '
            '\'{"row":2,"col":3}\' for reveal_cell or flag_cell, '
            '\'{"flagged_positions":"[[0,1],[2,3]]"}\' for submit_solution'
        ),
    )

    @classmethod
    def model_validate(cls, data: Any, **kwargs: Any) -> Action:
        """Validate *data*, routing raw MCP payloads to the MCP action types.

        A dict whose ``"type"`` is ``"call_tool"`` or ``"list_tools"`` is
        validated as ``CallToolAction`` / ``ListToolsAction``, so the result
        is NOT a ``VisualMemoryAction`` in that case, and ``kwargs`` are not
        used. Anything else goes through normal Pydantic validation of this
        model with ``kwargs`` forwarded.
        """
        if isinstance(data, dict) and data.get("type") in ("call_tool", "list_tools"):
            return _mcp_action_adapter.validate_python(data)
        return super().model_validate(data, **kwargs)

    def to_mcp_action(self) -> Action:
        """Convert this form-style action into an MCP action.

        Returns:
            ``ListToolsAction()`` when ``tool_name`` is ``"list_tools"``
            (``arguments_json`` is ignored); otherwise a ``CallToolAction``
            whose ``arguments`` is the decoded JSON object. An empty or
            whitespace-only ``arguments_json``, or JSON ``null``, yields ``{}``.

        Raises:
            json.JSONDecodeError: if ``arguments_json`` is not valid JSON.
            ValueError: if ``arguments_json`` decodes to something other than
                a JSON object (MCP tool arguments must be an object).
        """
        if self.tool_name == "list_tools":
            return ListToolsAction()

        # Web form fields often arrive blank or padded with whitespace; treat
        # those like "no arguments" instead of failing inside json.loads.
        raw = self.arguments_json.strip() if self.arguments_json else ""
        args = _json.loads(raw) if raw else {}
        if args is None:
            args = {}
        elif not isinstance(args, dict):
            raise ValueError(
                f"arguments_json for tool {self.tool_name!r} must be a JSON "
                f"object, got {type(args).__name__}"
            )
        return CallToolAction(tool_name=self.tool_name, arguments=args)


# The environment reuses the generic MCP observation and OpenEnv state types
# unchanged; these aliases give them environment-specific public names.
VisualMemoryObservation = CallToolObservation
VisualMemoryState = State

# Public API: the environment-specific names plus re-exports of the MCP
# action/observation types imported at the top of this module.
__all__ = [
    "VisualMemoryAction",
    "VisualMemoryObservation",
    "VisualMemoryState",
    "CallToolAction",
    "CallToolObservation",
    "ListToolsAction",
    "ListToolsObservation",
]