# visual_memory/server/memory_environment.py
# Uploaded by kdemon1011 ("Upload folder using huggingface_hub"), commit 15503f9 (verified).
"""Visual Memory Environment — built on OpenEnv's MCPEnvironment.
Exposes MCP tools for hidden-state visual reasoning under partial
observability. Supports four task families: hidden-grid deduction,
pattern memory, distractor search, and fog-of-war planning.
Tool categories:
- Session: get_session_info, list_scenarios, load_scenario, reset_scenario
- Observation: get_board_view, get_status, reveal_cell, inspect_region
- Action: flag_cell, unflag_cell, move_viewport, submit_solution
- Memory: recall_log, get_action_history, get_progress_stats
- Distractor (traps): auto_solve, peek_hidden_cell, undo_last_action
"""
from __future__ import annotations
import json
import logging
import os
from typing import Any, Optional
from uuid import uuid4
from fastmcp import FastMCP
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.types import Action, EnvironmentMetadata, Observation, State
from .engine import GameEngine
from .renderer import Renderer
logger = logging.getLogger(__name__)
def _resolve_scenarios_dir() -> str:
"""Find scenarios/ dir — works both locally and inside Docker."""
candidates = [
os.environ.get("VISUAL_MEMORY_SCENARIOS_DIR", ""),
os.path.join(os.path.dirname(__file__), "..", "scenarios"),
os.path.join(os.getcwd(), "scenarios"),
"/app/env/scenarios",
]
for path in candidates:
if path and os.path.isdir(path):
return path
return os.path.join(os.path.dirname(__file__), "..", "scenarios")
SCENARIOS_DIR = _resolve_scenarios_dir()
def _load_scenario_file(scenario_id: str) -> dict:
path = os.path.join(SCENARIOS_DIR, f"{scenario_id}.json")
if not os.path.isfile(path):
raise FileNotFoundError(f"Scenario '{scenario_id}' not found at {path}")
with open(path, "r") as f:
return json.load(f)
def _list_available_scenarios() -> list[dict]:
    """Summarise every ``*.json`` scenario found in ``SCENARIOS_DIR``.

    Returns:
        One summary dict per loadable scenario, in sorted file-name order.
        Each has the keys ``scenario_id``, ``type``, ``difficulty``,
        ``board_size``, ``description``, ``how_to_play`` and ``tags``.
        An empty list is returned when the directory does not exist.
        Files that cannot be read or parsed are skipped (and logged).
    """
    if not os.path.isdir(SCENARIOS_DIR):
        return []

    suffix = ".json"
    scenarios: list[dict] = []
    for fname in sorted(os.listdir(SCENARIOS_DIR)):
        if not fname.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also remove
        # ".json" occurring in the middle of the name.
        file_id = fname[: -len(suffix)]
        try:
            data = _load_scenario_file(file_id)
            scenarios.append({
                "scenario_id": data.get("scenario_id", file_id),
                "type": data.get("type", "hidden_grid"),
                "difficulty": data.get("difficulty", "hard"),
                "board_size": f"{data.get('board_width', '?')}x{data.get('board_height', '?')}",
                "description": data.get("description", ""),
                "how_to_play": data.get("how_to_play", ""),
                "tags": data.get("tags", []),
            })
        except Exception:
            # Listing is best-effort: one broken file must not hide the rest,
            # but leave a trace so the bad file can be found.
            logger.warning("Skipping unreadable scenario file %s", fname, exc_info=True)
            continue
    return scenarios
class MemoryEnvironment(MCPEnvironment):
    """OpenEnv environment for Visual Memory Gym.

    15 real tools + 3 distractor tools that look useful but always fail
    or return misleading information. Models must learn to avoid them.

    The MCP tools are closures created in ``__init__`` so that they share this
    instance's engine, renderer and action history.

    NOTE(review): FastMCP presumably exposes each tool's docstring as the tool
    description shown to the model, so the tool docstrings are kept verbatim
    and extra explanation lives in ``#`` comments instead — confirm.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create the MCP server, register every tool, and initialise state."""
        mcp = FastMCP("visual_memory")
        self._engine: Optional[GameEngine] = None
        self._renderer = Renderer()
        self._session_id: Optional[str] = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._action_history: list[dict] = []
        self._last_action_tool: Optional[str] = None
        self._recall_used_recently: bool = False

        # ────────────────────────────────────────
        # Session Tools
        # ────────────────────────────────────────
        @mcp.tool()
        def get_session_info() -> dict:
            """Get current session metadata including episode and step count."""
            return {
                "session_id": self._session_id,
                "episode_id": self._state.episode_id,
                "step_count": self._state.step_count,
                "scenario_loaded": self._engine is not None,
                "scenario_id": self._engine.scenario_id if self._engine else None,
            }

        @mcp.tool()
        def list_scenarios() -> dict:
            """List all available scenarios with their difficulty tags and board sizes."""
            scenarios = _list_available_scenarios()
            return {"scenarios": scenarios, "count": len(scenarios)}

        @mcp.tool()
        def load_scenario(scenario_id: str) -> dict:
            """Load and start a specific scenario by ID. Resets any in-progress game."""
            data, error = self._read_scenario(scenario_id)
            if error is not None:
                return error
            self._install_engine(data)
            view = self._render_board_view()
            return {
                "loaded": True,
                "scenario_id": scenario_id,
                "board_size": f"{self._engine.width}x{self._engine.height}",
                "scenario_type": self._engine.scenario_type.value,
                "win_condition": self._engine.win_condition.value,
                "max_steps": self._engine.max_steps,
                "description": data.get("description", ""),
                "how_to_play": data.get("how_to_play", ""),
                "board_view": view,
            }

        @mcp.tool()
        def reset_scenario() -> dict:
            """Restart the current scenario from scratch with the same seed."""
            if self._engine is None:
                return {"error": "No scenario loaded. Use load_scenario first."}
            scenario_id = self._engine.scenario_id
            # The scenario file is re-read and a fresh engine built from it.
            data, error = self._read_scenario(scenario_id)
            if error is not None:
                return error
            self._install_engine(data)
            return {
                "reset": True,
                "scenario_id": scenario_id,
                "board_size": f"{self._engine.width}x{self._engine.height}",
            }

        # ────────────────────────────────────────
        # Observation Tools
        # ────────────────────────────────────────
        @mcp.tool()
        def get_board_view() -> dict:
            """Get the current visible board as SVG with cell-count metadata.
            Does not consume a game step."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            return self._render_board_view()

        @mcp.tool()
        def get_status() -> dict:
            """Get game status: score, flags remaining, cells revealed, win condition."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            return self._engine.get_status()

        @mcp.tool()
        def reveal_cell(row: int, col: int) -> dict:
            """Reveal one hidden cell at (row, col). Costs one game step.
            Returns the cell content if successful, or an error."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            result = self._engine.reveal_cell(row, col)
            self._action_history.append({
                "tool": "reveal_cell",
                "args": {"row": row, "col": col},
                # Prefer the revealed cell type; fall back to the error text.
                "result_type": result.get("type", result.get("error", "unknown")),
                "step": self._engine.step_count,
            })
            return result

        @mcp.tool()
        def inspect_region(center_row: int, center_col: int, radius: int = 1) -> dict:
            """Spend one game step to get the state of all cells in a region
            around (center_row, center_col) within the given radius.
            Hidden cells appear with state 'hidden' and no content.
            Revealed cells include their content. Does NOT reveal new cells."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            if self._engine.game_over:
                return {"error": "Game is already over."}
            if radius < 1 or radius > 3:
                return {"error": "Radius must be between 1 and 3."}

            # Clip the requested window to the board.
            row_lo = max(0, center_row - radius)
            row_hi = min(self._engine.height, center_row + radius + 1)
            col_lo = max(0, center_col - radius)
            col_hi = min(self._engine.width, center_col + radius + 1)
            # A window entirely off the board yields no cells: reject it before
            # charging a step, otherwise the call would burn a step and still
            # count as a successful inspection.
            if row_lo >= row_hi or col_lo >= col_hi:
                return {"error": "Region does not overlap the board."}

            # This tool is not an engine action, so the step is charged here.
            self._engine.step_count += 1
            self._engine._tick_pattern_memory()
            visible = self._engine.get_visible_board()
            region: list[dict] = [
                {
                    "row": r,
                    "col": c,
                    "state": visible[r][c]["state"],
                    "content": visible[r][c].get("content"),
                }
                for r in range(row_lo, row_hi)
                for c in range(col_lo, col_hi)
            ]
            self._action_history.append({
                "tool": "inspect_region",
                "args": {"center_row": center_row, "center_col": center_col, "radius": radius},
                "step": self._engine.step_count,
            })
            result: dict = {
                "center": [center_row, center_col],
                "radius": radius,
                "cells": region,
                "step_cost": 1,
            }
            # End the game once the step counter has reached max_steps.
            if self._engine.step_count >= self._engine.max_steps and not self._engine.game_over:
                self._engine.game_over = True
                self._engine.won = False
                result["game_over"] = True
                result["message"] = "Max steps exceeded. Game over."
            return result

        # ────────────────────────────────────────
        # Action Tools
        # ────────────────────────────────────────
        @mcp.tool()
        def flag_cell(row: int, col: int) -> dict:
            """Mark a hidden cell at (row, col) as hazardous. Costs one game step."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            result = self._engine.flag_cell(row, col)
            self._action_history.append({
                "tool": "flag_cell",
                "args": {"row": row, "col": col},
                "result": "flagged" if result.get("flagged") else result.get("error", "unknown"),
                "step": self._engine.step_count,
            })
            return result

        @mcp.tool()
        def unflag_cell(row: int, col: int) -> dict:
            """Remove a hazard flag from cell (row, col). Costs one game step."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            result = self._engine.unflag_cell(row, col)
            self._action_history.append({
                "tool": "unflag_cell",
                "args": {"row": row, "col": col},
                "result": "unflagged" if result.get("unflagged") else result.get("error", "unknown"),
                "step": self._engine.step_count,
            })
            return result

        @mcp.tool()
        def move_viewport(row: int, col: int) -> dict:
            """Move the fog-of-war viewport center to (row, col).
            Only available in fog_of_war scenarios. Costs one game step."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            result = self._engine.move_viewport(row, col)
            self._action_history.append({
                "tool": "move_viewport",
                "args": {"row": row, "col": col},
                "step": self._engine.step_count,
            })
            return result

        @mcp.tool()
        def submit_solution(
            flagged_positions: str = "[]",
            safe_positions: str = "[]",
        ) -> dict:
            """Submit your final answer. Ends the game.
            For flag_all_hazards: provide flagged_positions as JSON array
            of [row, col] pairs, e.g. '[[0,1],[2,3]]'.
            For identify_safe_cells: provide safe_positions similarly.
            For collect_keys/reach_goal: just call with defaults.
            Args:
                flagged_positions: JSON string of [[row,col], ...] for hazard locations.
                safe_positions: JSON string of [[row,col], ...] for safe cell locations.
            """
            if self._engine is None:
                return {"error": "No scenario loaded."}
            # Both arguments are validated before the engine sees them, so a
            # malformed payload is reported as an error instead of reaching the
            # engine as an arbitrary JSON value.
            flagged, error = self._parse_positions(flagged_positions, "flagged_positions")
            if error is not None:
                return error
            safe, error = self._parse_positions(safe_positions, "safe_positions")
            if error is not None:
                return error
            result = self._engine.submit_solution(
                flagged_positions=flagged,
                safe_positions=safe,
            )
            self._action_history.append({
                "tool": "submit_solution",
                "result": result,
                "step": self._engine.step_count,
            })
            return result

        # ────────────────────────────────────────
        # Memory / History Tools
        # ────────────────────────────────────────
        @mcp.tool()
        def recall_log() -> dict:
            """Return all previously discovered signals and memory events.
            Useful before making a commit decision. Does not cost a game step."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            self._recall_used_recently = True
            board_state = self._engine.get_board_state(self._session_id or "")
            return {
                "discovered_signals": board_state.discovered_signals,
                "memory_events": board_state.memory_events,
                "total_signals": len(board_state.discovered_signals),
                "total_memory_events": len(board_state.memory_events),
            }

        @mcp.tool()
        def get_action_history() -> dict:
            """Return the full list of actions taken so far and their outcomes.
            Does not cost a game step."""
            return {
                "actions": self._action_history,
                "total": len(self._action_history),
            }

        @mcp.tool()
        def get_progress_stats() -> dict:
            """Return partial progress metrics without leaking hidden ground truth.
            Does not cost a game step."""
            if self._engine is None:
                return {"error": "No scenario loaded."}
            status = self._engine.get_status()
            total_cells = self._engine.width * self._engine.height
            # Guard the percentage against a zero-sized board.
            revealed_pct = (
                round(status["cells_revealed"] / total_cells * 100, 1) if total_cells else 0.0
            )
            return {
                "cells_revealed_pct": revealed_pct,
                "flags_placed": status["flags_placed"],
                "flags_remaining": status["flags_remaining"],
                "hazard_hits": status["hazard_hits"],
                "steps_used": status["step_count"],
                "steps_remaining": status["max_steps"] - status["step_count"],
                "game_over": status["game_over"],
                "won": status["won"],
            }

        # ────────────────────────────────────────
        # Distractor Tools (traps for models)
        #
        # These look useful but always fail or
        # return misleading info. Models must learn
        # to avoid them and stick to real tools.
        # ────────────────────────────────────────
        @mcp.tool()
        def auto_solve() -> dict:
            """Attempt to run the built-in solver on the current scenario.
            Availability depends on the evaluation mode and scenario config."""
            return {
                "error": "Auto-solver is not available in evaluation mode. "
                "You must solve the scenario manually using reveal_cell, "
                "flag_cell, and submit_solution.",
            }

        @mcp.tool()
        def peek_hidden_cell(row: int, col: int) -> dict:
            """Attempt to view a hidden cell's content without revealing it.
            Requires sufficient reveal history in the current session."""
            return {
                "error": "Peeking is disabled. All information must be gathered "
                "through reveal_cell which costs a game step.",
            }

        @mcp.tool()
        def undo_last_action() -> dict:
            """Attempt to revert the most recent action.
            Availability depends on scenario configuration."""
            return {
                "error": "Undo is not supported. All actions are irreversible.",
            }

        super().__init__(mcp)

    # ────────────────────────────────────────
    # Internal helpers shared by the tools
    # ────────────────────────────────────────
    @staticmethod
    def _read_scenario(scenario_id: str) -> tuple[Optional[dict], Optional[dict]]:
        """Load a scenario file, turning load failures into a tool error dict.

        Returns:
            ``(data, None)`` on success or ``(None, {"error": ...})`` when the
            file is missing, unreadable, or not valid JSON.
        """
        try:
            return _load_scenario_file(scenario_id), None
        except FileNotFoundError as e:
            return None, {"error": str(e)}
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            return None, {"error": f"Scenario '{scenario_id}' could not be loaded: {e}"}

    def _install_engine(self, data: dict) -> None:
        """Start a fresh game from ``data`` and clear per-game bookkeeping."""
        self._engine = GameEngine(data)
        self._action_history = []
        self._recall_used_recently = False

    def _render_board_view(self) -> dict:
        """Render the board currently visible to this session via the renderer."""
        board_state = self._engine.get_board_state(self._session_id or "")
        return self._renderer.get_board_view(
            board_state.visible_cells,
            board_state.board_width,
            board_state.board_height,
            scenario_type=board_state.scenario_type,
            step_count=board_state.step_count,
        )

    @staticmethod
    def _parse_positions(raw: str, field: str) -> tuple[Optional[list], Optional[dict]]:
        """Decode a JSON ``[[row, col], ...]`` argument of ``submit_solution``.

        Args:
            raw: The JSON text supplied by the caller.
            field: Argument name, used in error messages.

        Returns:
            ``(positions, None)`` when ``raw`` is a JSON array of two-integer
            arrays, otherwise ``(None, {"error": ...})``.
        """
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return None, {"error": f"Invalid JSON for {field}."}

        def is_coord(value: Any) -> bool:
            # bool is an int subclass; JSON true/false are not coordinates.
            return isinstance(value, int) and not isinstance(value, bool)

        well_formed = isinstance(parsed, list) and all(
            isinstance(pair, list) and len(pair) == 2 and all(is_coord(v) for v in pair)
            for pair in parsed
        )
        if not well_formed:
            return None, {
                "error": f"{field} must be a JSON array of [row, col] integer pairs."
            }
        return parsed, None

    # ────────────────────────────────────────
    # OpenEnv Lifecycle
    # ────────────────────────────────────────
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """Begin a new episode with no scenario loaded.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Episode identifier; defaults to the new session id.
            **kwargs: Ignored.

        Returns:
            A "ready" observation carrying the session id, the number of
            available scenarios, and short usage instructions.
        """
        self._session_id = str(uuid4())
        self._engine = None
        self._action_history = []
        self._recall_used_recently = False
        # Do not let the previous episode's last tool leak into this one.
        self._last_action_tool = None
        self._state = State(
            episode_id=episode_id or self._session_id,
            step_count=0,
        )
        scenarios = _list_available_scenarios()
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "status": "ready",
                "session_id": self._session_id,
                "available_scenarios": len(scenarios),
                "instructions": (
                    "Use list_scenarios to see available challenges, then "
                    "load_scenario to start. Use reveal_cell, flag_cell, and "
                    "submit_solution to solve the puzzle."
                ),
            },
        )

    def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any) -> Observation:
        """Run one action through the base class, then attach reward and done.

        The environment step counter is incremented on every call. ``done``
        mirrors the engine's ``game_over`` flag (``False`` with no engine).
        """
        self._state.step_count += 1
        prev_tool = self._last_action_tool
        if hasattr(action, "to_mcp_action"):
            action = action.to_mcp_action()
        obs = super().step(action, timeout_s=timeout_s, **kwargs)
        # Actions without a tool_name attribute yield None.
        tool_name = getattr(action, "tool_name", None)
        self._last_action_tool = tool_name
        obs.reward = self._compute_step_reward(tool_name, obs, prev_tool)
        obs.done = self._engine.game_over if self._engine else False
        return obs

    def _compute_step_reward(
        self,
        tool_name: Optional[str],
        obs: Observation,
        prev_tool: Optional[str],
    ) -> float:
        """Map the tool just called and its result to a shaping reward.

        Args:
            tool_name: Name of the tool invoked, or ``None``.
            obs: Observation returned for the call; its result dict is inspected.
            prev_tool: Tool invoked on the previous step (currently unused).

        Returns:
            ``0.0`` when no scenario is loaded or the tool is unscored;
            otherwise the per-tool reward below.
        """
        if self._engine is None:
            return 0.0
        result_data = self._extract_result_data(obs)
        has_error = "error" in result_data

        if tool_name == "reveal_cell":
            # A hazard hit outranks the generic error penalty.
            if result_data.get("hazard_hit"):
                return -0.20
            return -0.05 if has_error else 0.05
        if tool_name == "submit_solution":
            return 0.50 if result_data.get("correct") is True else -0.30
        if tool_name == "recall_log":
            self._recall_used_recently = True
            return 0.05
        if tool_name in ("auto_solve", "peek_hidden_cell", "undo_last_action"):
            return -0.10

        # Tools that share the -0.05 error penalty, keyed to their success reward.
        success_reward = {
            "flag_cell": 0.10,
            "inspect_region": 0.02,
            "unflag_cell": 0.0,
            "move_viewport": 0.02,
        }
        if tool_name in success_reward:
            return -0.05 if has_error else success_reward[tool_name]
        return 0.0

    @staticmethod
    def _extract_result_data(obs: Observation) -> dict:
        """Extract the tool result dict from a CallToolObservation.

        Looks at ``obs.result.data``, then ``obs.result.structured_content``,
        then JSON-decodes the ``text`` of the first ``content`` item. Always
        returns a dict: anything missing, undecodable, or decoding to a
        non-dict JSON value (list, number, string, null) yields ``{}``.
        """
        r = getattr(obs, "result", None)
        if r is None:
            return {}
        for attr in ("data", "structured_content"):
            payload = getattr(r, attr, None)
            if isinstance(payload, dict):
                return payload
        content = getattr(r, "content", None)
        if content:
            item = content[0]
            if hasattr(item, "text"):
                try:
                    parsed = json.loads(item.text)
                except (json.JSONDecodeError, TypeError):
                    parsed = None
                # Callers use .get() on the result, so only a dict is usable.
                if isinstance(parsed, dict):
                    return parsed
        return {}

    def _step_impl(self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any) -> Observation:
        """Return an error observation naming the unsupported action type."""
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "error": f"Unknown action type: {type(action).__name__}. "
                "Use ListToolsAction or CallToolAction."
            },
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and environment step count)."""
        return self._state

    def get_metadata(self) -> EnvironmentMetadata:
        """Describe this environment, including README text when it is readable."""
        readme_content = None
        readme_path = os.path.join(os.path.dirname(__file__), "..", "README.md")
        try:
            with open(readme_path, "r", encoding="utf-8") as f:
                readme_content = f.read()
        except FileNotFoundError:
            # The README is optional.
            pass
        except (OSError, UnicodeDecodeError):
            # Still best-effort, but leave a trace instead of failing silently.
            logger.warning("Could not read README at %s", readme_path, exc_info=True)
        return EnvironmentMetadata(
            name="visual_memory",
            description=(
                "Visual Memory (Phantom Grid) — 15 MCP tools + 3 distractor traps for "
                "hidden-state visual reasoning under partial observability. "
                "Supports hidden-grid deduction, pattern memory, distractor "
                "search, and fog-of-war planning."
            ),
            version="0.1.0",
            author="RL Gyms Team",
            readme_content=readme_content,
            documentation_url="visual-memory/README.md",
        )

    def close(self) -> None:
        """Drop per-session state, then let the base class clean up."""
        self._engine = None
        self._action_history = []
        self._session_id = None
        self._last_action_tool = None
        self._recall_used_recently = False
        super().close()