# visual_memory/server/engine.py
# Uploaded by kdemon1011 ("Upload folder using huggingface_hub"), revision 816634a (verified).
"""Hidden-state game engine for Visual Memory Gym.

Manages in-memory board state, hidden cell contents, move validation,
and win/loss conditions across four task families:

1. hidden_grid — deduce hazard locations from signal clues
2. pattern_memory — recall briefly-shown cell contents
3. distractor_search — identify targets among visually similar decoys
4. fog_of_war — plan under limited viewport radius
"""
from __future__ import annotations
import copy
from enum import Enum
from typing import Any
import numpy as np
from pydantic import BaseModel, Field
class CellType(str, Enum):
    """Kinds of content a hidden board cell can hold.

    Members are ``str`` subclasses, so each one compares equal to its raw
    string value (the form stored in the board dictionaries).
    """

    # Nothing special; the default for cells that receive no other content.
    EMPTY = "empty"
    # Revealing one counts towards the engine's hazard-reveal limit.
    HAZARD = "hazard"
    # Formerly empty cell carrying a clue about hazards next to it.
    SIGNAL = "signal"
    # Collected when revealed.
    KEY = "key"
    # Placed according to the scenario's decoy count.
    DECOY = "decoy"
    # Placed only for the reach-goal win condition.
    GOAL = "goal"
class CellState(str, Enum):
    """Visibility status of a cell as tracked on the visible board.

    Members are ``str`` subclasses and compare equal to their raw values.
    """

    # Initial state of every cell.
    HIDDEN = "hidden"
    # Content is currently exposed.
    REVEALED = "revealed"
    # Marked with a flag; content stays unexposed.
    FLAGGED = "flagged"
    # Was shown during the pattern-memory flash and has been hidden again.
    FADED = "faded"
class ScenarioType(str, Enum):
    """The four task families a scenario can belong to.

    Members are ``str`` subclasses and compare equal to their raw values.
    """

    # Deduce hazard locations from signal clues.
    HIDDEN_GRID = "hidden_grid"
    # Recall cell contents that were shown briefly at the start.
    PATTERN_MEMORY = "pattern_memory"
    # Identify targets among visually similar decoys.
    DISTRACTOR_SEARCH = "distractor_search"
    # Plan with only a limited viewport visible.
    FOG_OF_WAR = "fog_of_war"
class SignalMode(str, Enum):
    """How a signal cell describes the hazards adjacent to it.

    Members are ``str`` subclasses and compare equal to their raw values.
    """

    # Exact number of adjacent hazards.
    COUNT = "count"
    # Compass directions of all adjacent hazards.
    DIRECTIONAL = "directional"
    # A {"min", "max"} interval around the adjacent-hazard count.
    RANGE = "range"
    # A random subset of the hazard directions plus the total count.
    PARTIAL = "partial"
class WinCondition(str, Enum):
    """Objectives that decide whether a game counts as won.

    Members are ``str`` subclasses and compare equal to their raw values.
    """

    # Every hazard flagged and no flag on a non-hazard cell.
    FLAG_ALL_HAZARDS = "flag_all_hazards"
    # Every key on the board collected.
    COLLECT_KEYS = "collect_keys"
    # Submit exactly the set of non-hazard cells.
    IDENTIFY_SAFE = "identify_safe_cells"
    # Reveal the goal cell.
    REACH_GOAL = "reach_goal"
class BoardState(BaseModel):
    """Serializable snapshot of the game as the agent may see it.

    Only the visible portion of the board is included; hidden cell
    contents never appear here.
    """

    # --- identification ---
    session_id: str = ""
    scenario_id: str = ""
    scenario_type: str = "hidden_grid"

    # --- progress and geometry ---
    step_count: int = 0
    board_width: int = 0
    board_height: int = 0

    # --- what the agent has seen so far ---
    # Row-major grid of {"state", "content"} dictionaries.
    visible_cells: list[list[dict]] = Field(default_factory=list)
    discovered_signals: list[dict] = Field(default_factory=list)
    memory_events: list[dict] = Field(default_factory=list)

    # --- outcome ---
    game_over: bool = False
    won: bool = False

    # --- counters ---
    flags_remaining: int = 0
    cells_revealed: int = 0
    hazard_hits: int = 0
    keys_collected: int = 0
    max_steps: int = 50
# (row delta, column delta) of the eight cells surrounding a cell, scanned
# row by row from the top-left neighbour to the bottom-right one. The
# centre (0, 0) is skipped because a cell is not its own neighbour.
NEIGHBOR_OFFSETS = [
    (d_row, d_col)
    for d_row in (-1, 0, 1)
    for d_col in (-1, 0, 1)
    if (d_row, d_col) != (0, 0)
]

# Compass label for each neighbour offset, listed in the same scan order
# (row -1 is north, column -1 is west).
DIRECTION_NAMES = dict(
    zip(NEIGHBOR_OFFSETS, ("NW", "N", "NE", "W", "E", "SW", "S", "SE"))
)
class GameEngine:
"""In-memory game engine for the Visual Memory gym.
Deterministic given a seed. All state lives in Python memory.
"""
def __init__(self, scenario: dict, seed: int | None = None):
    """Set up a new game from a scenario description.

    Args:
        scenario: Scenario settings. ``scenario_id``, ``board_width`` and
            ``board_height`` are mandatory; every other key falls back to
            a default. When ``layout`` is present the board is taken from
            it, otherwise one is generated randomly.
        seed: Seed for the random generator. When ``None`` the scenario's
            ``seed`` entry is used, or 42 if that is missing as well.

    Raises:
        KeyError: A mandatory scenario key is missing.
        ValueError: An enum-valued setting is not recognised, or the
            requested special cells do not fit on a generated board.
    """
    # Static configuration taken from the scenario.
    self.scenario_id: str = scenario["scenario_id"]
    self.scenario_type = ScenarioType(scenario.get("type", "hidden_grid"))
    self.width: int = scenario["board_width"]
    self.height: int = scenario["board_height"]
    self.max_steps: int = scenario.get("max_steps", 50)
    self.max_hazard_reveals: int = scenario.get("max_hazard_reveals", 3)
    self.signal_mode = SignalMode(scenario.get("signal_mode", "count"))
    win_spec = scenario.get("win_condition", {})
    self.win_condition = WinCondition(win_spec.get("type", "flag_all_hazards"))

    # An explicit seed argument wins over the scenario's own seed.
    if seed is None:
        seed = scenario.get("seed", 42)
    self._rng = np.random.default_rng(seed)

    # Running counters and outcome flags.
    self.step_count: int = 0
    self.hazard_hits: int = 0
    self.keys_collected: int = 0
    self.cells_revealed: int = 0
    self.game_over: bool = False
    self.won: bool = False

    # Hidden truth: either supplied verbatim or generated from counts.
    if "layout" in scenario:
        self._hidden = self._load_explicit_layout(scenario["layout"])
    else:
        self._hidden = self._generate_board(scenario)

    # What the agent sees; everything starts hidden.
    self._visible: list[list[dict]] = []
    for _ in range(self.height):
        self._visible.append(
            [
                {"state": CellState.HIDDEN.value, "content": None}
                for _ in range(self.width)
            ]
        )

    def tally(kind: str) -> int:
        # Number of hidden cells of the given type inside the declared
        # width x height area.
        return sum(
            self._hidden[r][c]["type"] == kind
            for r in range(self.height)
            for c in range(self.width)
        )

    total_hazards = tally(CellType.HAZARD.value)
    # Default flag budget leaves three spare flags beyond the hazard count.
    self.total_flags: int = scenario.get("flags_count", total_hazards + 3)
    self.flags_placed: int = 0
    self.total_keys: int = tally(CellType.KEY.value)

    # Histories exposed through the state queries.
    self._discovered_signals: list[dict] = []
    self._memory_events: list[dict] = []
    self._action_log: list[dict] = []

    # Fog-of-war viewport (both may be absent).
    self._viewport_center: list[int] | None = scenario.get("start_position")
    self._viewport_radius: int | None = scenario.get("viewport_radius")

    # Pattern-memory flash configuration.
    self._flash_cells: list[list[int]] = scenario.get("flash_cells", [])
    self._flash_until_step: int = scenario.get("flash_until_step", 0)

    if self._flash_cells and self.scenario_type == ScenarioType.PATTERN_MEMORY:
        # Show the flash cells up front and record each one as a memory
        # event so the episode history reflects what was displayed.
        for entry in self._flash_cells:
            row, col = entry[0], entry[1]
            shown = self._hidden[row][col]
            self._visible[row][col] = {
                "state": CellState.REVEALED.value,
                "content": copy.deepcopy(shown),
            }
            self._memory_events.append(
                {
                    "step": 0,
                    "event": "flash_shown",
                    "row": row,
                    "col": col,
                    "content": copy.deepcopy(shown),
                }
            )
# ─── Board Generation ───────────────────────────────────────────
def _load_explicit_layout(self, layout: list[list[dict]]) -> list[list[dict]]:
    """Normalise a scenario-supplied layout into the hidden-board format.

    Args:
        layout: Rows of cell dictionaries. Missing ``type`` defaults to
            the empty cell type, missing ``value`` to ``None`` and
            missing ``properties`` to an empty dictionary.

    Returns:
        A new list of rows in which every cell has exactly the keys
        ``type``, ``value`` and ``properties``. A ``properties`` object
        present in the input is reused as-is, not copied.
    """
    return [
        [
            {
                "type": spec.get("type", CellType.EMPTY.value),
                "value": spec.get("value"),
                "properties": spec.get("properties", {}),
            }
            for spec in row_spec
        ]
        for row_spec in layout
    ]
def _generate_board(self, scenario: dict) -> list[list[dict]]:
    """Create a random hidden board from the scenario's cell counts.

    Hazards, keys, decoys and (for the reach-goal win condition) one goal
    are dropped onto distinct cells chosen by a single random permutation
    of all cell indices, in that order. Signals are then computed for the
    remaining empty cells.

    Args:
        scenario: Scenario settings; reads ``hazard_count`` (default 10),
            ``key_count`` (default 0) and ``decoy_count`` (default 0).

    Returns:
        The generated board as rows of cell dictionaries.

    Raises:
        ValueError: More special cells were requested than the board has.
    """
    hazard_count = scenario.get("hazard_count", 10)
    key_count = scenario.get("key_count", 0)
    decoy_count = scenario.get("decoy_count", 0)
    goal_count = 1 if self.win_condition == WinCondition.REACH_GOAL else 0

    total_cells = self.width * self.height
    total_special = hazard_count + key_count + decoy_count + goal_count
    if total_special > total_cells:
        raise ValueError(
            f"Cannot place {total_special} special cells on a "
            f"{self.width}x{self.height} board ({total_cells} cells)"
        )

    # One shuffled ordering of all flat cell indices; consecutive entries
    # are consumed so no two special cells can collide.
    positions = self._rng.permutation(total_cells)

    board: list[list[dict]] = []
    for _ in range(self.height):
        board.append(
            [
                {"type": CellType.EMPTY.value, "value": None, "properties": {}}
                for _ in range(self.width)
            ]
        )

    # (type, value) for every special cell, in placement order.
    placements: list[tuple[str, Any]] = []
    placements += [(CellType.HAZARD.value, None)] * hazard_count
    placements += [(CellType.KEY.value, f"key_{i}") for i in range(key_count)]
    placements += [(CellType.DECOY.value, f"decoy_{i}") for i in range(decoy_count)]
    placements += [(CellType.GOAL.value, None)] * goal_count

    for slot, (kind, value) in enumerate(placements):
        # Flat index -> (row, column).
        r, c = divmod(int(positions[slot]), self.width)
        board[r][c] = {"type": kind, "value": value, "properties": {}}

    self._compute_signals(board)
    return board
def _compute_signals(self, board: list[list[dict]]) -> None:
    """Turn empty cells that touch a hazard into signal cells, in place.

    Cells are visited row by row. What a signal carries depends on
    ``self.signal_mode``:

    * count: the number of adjacent hazards;
    * directional: the compass names of all adjacent hazards;
    * range: a ``{"min", "max"}`` interval obtained by widening the true
      count by a random 0 or 1 on each side (``min`` never below 0);
    * partial: a random half (at least one) of the hazard directions,
      with the full number of directions stored as ``total_hint``.

    Args:
        board: Hidden board to modify; non-empty cells are left alone.
    """
    mode = self.signal_mode
    for r in range(self.height):
        for c in range(self.width):
            if board[r][c]["type"] != CellType.EMPTY.value:
                continue

            # (value, properties) of the signal to write, if any.
            signal: tuple[Any, dict] | None = None

            if mode == SignalMode.COUNT:
                nearby = self._count_adjacent_hazards(board, r, c)
                if nearby > 0:
                    signal = (nearby, {"mode": "count"})

            elif mode == SignalMode.DIRECTIONAL:
                bearings = self._get_hazard_directions(board, r, c)
                if bearings:
                    signal = (bearings, {"mode": "directional"})

            elif mode == SignalMode.RANGE:
                nearby = self._count_adjacent_hazards(board, r, c)
                if nearby > 0:
                    # Two draws per signal, lower bound first; the draw
                    # order is part of the seeded determinism.
                    slack_below = int(self._rng.integers(0, 2))
                    lower = max(0, nearby - slack_below)
                    slack_above = int(self._rng.integers(0, 2))
                    upper = nearby + slack_above
                    signal = ({"min": lower, "max": upper}, {"mode": "range"})

            elif mode == SignalMode.PARTIAL:
                bearings = self._get_hazard_directions(board, r, c)
                if bearings:
                    keep = max(1, len(bearings) // 2)
                    picked = self._rng.choice(
                        len(bearings), size=keep, replace=False
                    )
                    # Sorting keeps the disclosed names in their original
                    # relative order.
                    disclosed = [bearings[i] for i in sorted(picked)]
                    signal = (
                        disclosed,
                        {"mode": "partial", "total_hint": len(bearings)},
                    )

            if signal is not None:
                value, properties = signal
                board[r][c] = {
                    "type": CellType.SIGNAL.value,
                    "value": value,
                    "properties": properties,
                }
def _count_adjacent_hazards(self, board: list[list[dict]], r: int, c: int) -> int:
    """Count the hazards among the up-to-eight neighbours of ``(r, c)``.

    Neighbours that would fall outside the ``height`` x ``width`` area
    are ignored.
    """
    return sum(
        1
        for d_row, d_col in NEIGHBOR_OFFSETS
        if 0 <= r + d_row < self.height
        and 0 <= c + d_col < self.width
        and board[r + d_row][c + d_col]["type"] == CellType.HAZARD.value
    )
def _get_hazard_directions(self, board: list[list[dict]], r: int, c: int) -> list[str]:
    """List the compass names of the neighbours of ``(r, c)`` that are hazards.

    Names appear in the iteration order of ``DIRECTION_NAMES``; neighbours
    outside the ``height`` x ``width`` area are ignored.
    """
    return [
        label
        for (d_row, d_col), label in DIRECTION_NAMES.items()
        if 0 <= r + d_row < self.height
        and 0 <= c + d_col < self.width
        and board[r + d_row][c + d_col]["type"] == CellType.HAZARD.value
    ]
# ─── Pattern Memory Phase ───────────────────────────────────────
def _tick_pattern_memory(self) -> None:
    """Hide the flashed cells again once the flash window ends.

    Does nothing unless this is a pattern-memory scenario and the current
    step count equals ``_flash_until_step`` exactly. At that step every
    flash cell that is still revealed becomes faded (content removed) and
    a ``flash_faded`` memory event is recorded for it.
    """
    fade_due = (
        self.scenario_type == ScenarioType.PATTERN_MEMORY
        and self.step_count == self._flash_until_step
    )
    if not fade_due:
        return

    for entry in self._flash_cells:
        row, col = entry[0], entry[1]
        # Flash cells that are no longer revealed (e.g. flagged) are left alone.
        if self._visible[row][col]["state"] != CellState.REVEALED.value:
            continue
        self._visible[row][col] = {"state": CellState.FADED.value, "content": None}
        self._memory_events.append(
            {
                "step": self.step_count,
                "event": "flash_faded",
                "row": row,
                "col": col,
            }
        )
# ─── Core Actions ───────────────────────────────────────────────
def reveal_cell(self, row: int, col: int) -> dict:
    """Reveal one cell and apply whatever it contains.

    Rejected requests (game over, out of bounds, cell already revealed or
    flagged, cell outside the current viewport) return an ``error``
    dictionary and do not consume a step.

    A valid reveal consumes one step, exposes the cell on the visible
    board and then:

    * records the cell in the discovered-signals list if it is a signal;
    * counts a hazard hit, ending the game as lost once the hazard-reveal
      limit is reached;
    * collects a key, ending the game as won under the collect-keys
      condition once all keys are held;
    * ends the game as won if the goal is revealed under the reach-goal
      condition;
    * ends the game as lost if the step budget is now used up.

    Args:
        row: Row index of the cell.
        col: Column index of the cell.

    Returns:
        Either an error dictionary, or a dictionary with ``row``, ``col``,
        ``type``, ``value`` and ``properties`` plus optional
        ``hazard_hit``, ``key_collected``, ``game_over`` and ``message``.
        The returned data is independent of the engine's internal state.
    """
    if self.game_over:
        return {"error": "Game is already over.", "row": row, "col": col}
    if not self._in_bounds(row, col):
        return {"error": f"({row},{col}) is out of bounds.", "row": row, "col": col}

    vis = self._visible[row][col]
    if vis["state"] in (CellState.REVEALED.value, CellState.FLAGGED.value):
        return {
            "error": f"Cell ({row},{col}) is already {vis['state']}.",
            "row": row,
            "col": col,
        }

    if self._viewport_radius is not None and self._viewport_center is not None:
        vr, vc = self._viewport_center
        # The viewport is a square window around its centre.
        if abs(row - vr) > self._viewport_radius or abs(col - vc) > self._viewport_radius:
            return {
                "error": f"({row},{col}) is outside your current viewport.",
                "row": row,
                "col": col,
            }

    self.step_count += 1
    self._tick_pattern_memory()

    hidden = self._hidden[row][col]
    cell_type = hidden["type"]
    self._visible[row][col] = {
        "state": CellState.REVEALED.value,
        "content": copy.deepcopy(hidden),
    }
    self.cells_revealed += 1

    # Copy value/properties so the caller can never mutate the hidden
    # board through the returned dictionary.
    result: dict[str, Any] = {
        "row": row,
        "col": col,
        "type": cell_type,
        "value": copy.deepcopy(hidden.get("value")),
        "properties": copy.deepcopy(hidden.get("properties", {})),
    }

    if cell_type == CellType.SIGNAL.value:
        # Store a snapshot, not `result` itself: `result` may still gain
        # game_over/message keys below and is handed to the caller, and
        # neither should leak into the discovered-signals history.
        self._discovered_signals.append(copy.deepcopy(result))

    if cell_type == CellType.HAZARD.value:
        self.hazard_hits += 1
        result["hazard_hit"] = True
        if self.hazard_hits >= self.max_hazard_reveals:
            self.game_over = True
            self.won = False
            result["game_over"] = True
            result["message"] = "Too many hazards revealed. Game over."

    if cell_type == CellType.KEY.value:
        self.keys_collected += 1
        result["key_collected"] = True
        if (
            self.win_condition == WinCondition.COLLECT_KEYS
            and self.keys_collected >= self.total_keys
        ):
            self.game_over = True
            self.won = True
            result["game_over"] = True
            result["message"] = "All keys collected. You win!"

    if cell_type == CellType.GOAL.value and self.win_condition == WinCondition.REACH_GOAL:
        self.game_over = True
        self.won = True
        result["game_over"] = True
        result["message"] = "Goal reached. You win!"

    # Running out of steps only matters if the game was not already
    # decided by this very move.
    if self.step_count >= self.max_steps and not self.game_over:
        self.game_over = True
        self.won = False
        result["game_over"] = True
        result["message"] = "Max steps exceeded. Game over."

    self._action_log.append({
        "action": "reveal",
        "row": row,
        "col": col,
        "step": self.step_count,
        "result_type": cell_type,
    })
    return result
def flag_cell(self, row: int, col: int) -> dict:
    """Place a flag on a cell.

    Rejected requests (game over, out of bounds, cell already revealed or
    flagged, no flags left) return an ``error`` dictionary and do not
    consume a step. A valid flag consumes one step, may win the game
    under the flag-all-hazards condition, and otherwise loses the game if
    the step budget is now used up.

    Args:
        row: Row index of the cell.
        col: Column index of the cell.

    Returns:
        An error dictionary, or a dictionary with ``row``, ``col``,
        ``flagged``, ``flags_remaining`` and optional ``game_over`` /
        ``message``.
    """
    where = {"row": row, "col": col}

    if self.game_over:
        return {"error": "Game is already over.", **where}
    if not self._in_bounds(row, col):
        return {"error": f"({row},{col}) is out of bounds.", **where}

    state = self._visible[row][col]["state"]
    if state == CellState.REVEALED.value:
        return {"error": f"Cell ({row},{col}) is already revealed; cannot flag.", **where}
    if state == CellState.FLAGGED.value:
        return {"error": f"Cell ({row},{col}) is already flagged.", **where}
    if self.flags_placed >= self.total_flags:
        return {"error": "No flags remaining.", **where}

    self.step_count += 1
    self._tick_pattern_memory()

    self._visible[row][col] = {"state": CellState.FLAGGED.value, "content": None}
    self.flags_placed += 1
    self._action_log.append(
        {"action": "flag", "row": row, "col": col, "step": self.step_count}
    )
    self._check_flag_win()

    outcome: dict[str, Any] = {
        **where,
        "flagged": True,
        "flags_remaining": self.total_flags - self.flags_placed,
    }

    if self.game_over and self.won:
        outcome["game_over"] = True
        outcome["message"] = "All hazards correctly flagged. You win!"

    # A win on the last available step takes precedence over the budget.
    out_of_steps = self.step_count >= self.max_steps
    if out_of_steps and not self.game_over:
        self.game_over = True
        self.won = False
        outcome["game_over"] = True
        outcome["message"] = "Max steps exceeded. Game over."

    return outcome
def unflag_cell(self, row: int, col: int) -> dict:
    """Remove the flag from a cell, returning it to the hidden state.

    Rejected requests (game over, out of bounds, cell not flagged) return
    an ``error`` dictionary and do not consume a step. A valid unflag
    consumes one step and loses the game if the step budget is now used up.

    Args:
        row: Row index of the cell.
        col: Column index of the cell.

    Returns:
        An error dictionary, or a dictionary with ``row``, ``col``,
        ``unflagged``, ``flags_remaining`` and optional ``game_over`` /
        ``message``.
    """
    where = {"row": row, "col": col}

    if self.game_over:
        return {"error": "Game is already over.", **where}
    if not self._in_bounds(row, col):
        return {"error": f"({row},{col}) is out of bounds.", **where}
    if self._visible[row][col]["state"] != CellState.FLAGGED.value:
        return {"error": f"Cell ({row},{col}) is not flagged.", **where}

    self.step_count += 1
    self._tick_pattern_memory()

    self._visible[row][col] = {"state": CellState.HIDDEN.value, "content": None}
    self.flags_placed -= 1
    self._action_log.append(
        {"action": "unflag", "row": row, "col": col, "step": self.step_count}
    )

    outcome: dict[str, Any] = {
        **where,
        "unflagged": True,
        "flags_remaining": self.total_flags - self.flags_placed,
    }

    out_of_steps = self.step_count >= self.max_steps
    if out_of_steps and not self.game_over:
        self.game_over = True
        self.won = False
        outcome["game_over"] = True
        outcome["message"] = "Max steps exceeded. Game over."

    return outcome
def move_viewport(self, row: int, col: int) -> dict:
    """Re-centre the fog-of-war viewport on ``(row, col)``.

    Only available in fog-of-war scenarios. Rejected requests (wrong
    scenario type, game over, out of bounds) return an ``error``
    dictionary and do not consume a step. A valid move consumes one step
    and loses the game if the step budget is now used up.

    Args:
        row: Row index of the new viewport centre.
        col: Column index of the new viewport centre.

    Returns:
        An error dictionary, or a dictionary with ``viewport_center``,
        ``viewport_radius`` and ``visible_area``. When this move exhausts
        the step budget the dictionary also carries ``game_over`` and
        ``message``, matching the other actions.
    """
    if self.scenario_type != ScenarioType.FOG_OF_WAR:
        return {"error": "move_viewport is only available in fog_of_war scenarios."}
    if self.game_over:
        return {"error": "Game is already over."}
    if not self._in_bounds(row, col):
        return {"error": f"({row},{col}) is out of bounds."}

    self.step_count += 1
    self._tick_pattern_memory()

    self._viewport_center = [row, col]
    self._action_log.append({
        "action": "move_viewport",
        "row": row,
        "col": col,
        "step": self.step_count,
    })

    result: dict = {
        # A separate list, so callers cannot move the viewport by
        # mutating the returned value.
        "viewport_center": [row, col],
        "viewport_radius": self._viewport_radius,
        "visible_area": self._get_viewport_bounds(),
    }

    if self.step_count >= self.max_steps and not self.game_over:
        self.game_over = True
        self.won = False
        # Tell the caller the game just ended, as reveal/flag/unflag do;
        # previously the loss was only visible via a later status query.
        result["game_over"] = True
        result["message"] = "Max steps exceeded. Game over."

    return result
def submit_solution(
    self,
    flagged_positions: list[list[int]] | None = None,
    safe_positions: list[list[int]] | None = None,
) -> dict:
    """End the game and grade the final answer for the active win condition.

    Submitting consumes one step and always ends the game. Grading:

    * flag-all-hazards: ``flagged_positions`` (plus flags already on the
      board) are compared with the real hazards;
    * identify-safe-cells: ``safe_positions`` are compared with the real
      non-hazard cells;
    * collect-keys: succeeds if all keys have been collected;
    * reach-goal: always fails, because reaching the goal would already
      have ended the game.

    Args:
        flagged_positions: ``[row, col]`` pairs claimed to be hazards.
        safe_positions: ``[row, col]`` pairs claimed to be safe.

    Returns:
        The grading dictionary (always containing ``correct``), or an
        ``error`` dictionary if the game was already over.
    """
    if self.game_over:
        return {"error": "Game is already over."}

    self.step_count += 1
    self.game_over = True

    objective = self.win_condition

    if objective == WinCondition.FLAG_ALL_HAZARDS:
        return self._judge_flag_solution(flagged_positions or [])

    if objective == WinCondition.IDENTIFY_SAFE:
        return self._judge_safe_solution(safe_positions or [])

    if objective == WinCondition.COLLECT_KEYS:
        self.won = self.keys_collected >= self.total_keys
        return {
            "correct": self.won,
            "keys_collected": self.keys_collected,
            "keys_required": self.total_keys,
        }

    if objective == WinCondition.REACH_GOAL:
        self.won = False
        return {"correct": False, "message": "Goal was not reached before submission."}

    return {"error": "Unknown win condition."}
# ─── State Queries ──────────────────────────────────────────────
def get_visible_board(self) -> list[list[dict]]:
    """Return a copy of the board as the agent is allowed to see it.

    Without a viewport (radius or centre unset) this is a deep copy of
    the whole visible board. With a viewport, only the cells inside the
    square window around the centre are copied; every other cell is
    reported as ``{"state": "fog", "content": None}``.
    """
    if self._viewport_radius is None or self._viewport_center is None:
        return copy.deepcopy(self._visible)

    center_row, center_col = self._viewport_center
    radius = self._viewport_radius

    # Window clipped to the board; membership tests on ranges are O(1).
    window_rows = range(max(0, center_row - radius), min(self.height, center_row + radius + 1))
    window_cols = range(max(0, center_col - radius), min(self.width, center_col + radius + 1))

    return [
        [
            copy.deepcopy(self._visible[r][c])
            if r in window_rows and c in window_cols
            else {"state": "fog", "content": None}
            for c in range(self.width)
        ]
        for r in range(self.height)
    ]
def get_status(self) -> dict:
    """Summarise the game's configuration, counters and outcome.

    Returns:
        A flat dictionary of plain values; enum-valued settings are
        reported by their string value and the board size as
        ``"<width>x<height>"``.
    """
    flags_left = self.total_flags - self.flags_placed
    status: dict = {}

    # Scenario and step budget.
    status["scenario_id"] = self.scenario_id
    status["scenario_type"] = self.scenario_type.value
    status["step_count"] = self.step_count
    status["max_steps"] = self.max_steps
    status["board_size"] = f"{self.width}x{self.height}"

    # Progress counters.
    status["cells_revealed"] = self.cells_revealed
    status["hazard_hits"] = self.hazard_hits
    status["max_hazard_reveals"] = self.max_hazard_reveals
    status["keys_collected"] = self.keys_collected
    status["total_keys"] = self.total_keys
    status["flags_placed"] = self.flags_placed
    status["flags_remaining"] = flags_left

    # Outcome.
    status["game_over"] = self.game_over
    status["won"] = self.won
    status["win_condition"] = self.win_condition.value
    return status
def get_board_state(self, session_id: str = "") -> BoardState:
    """Package the agent-visible game state into a ``BoardState``.

    Args:
        session_id: Identifier stored in the snapshot unchanged.

    Returns:
        A snapshot whose list fields are copies, so later engine changes
        do not affect it.
    """
    snapshot: dict[str, Any] = {
        "session_id": session_id,
        "scenario_id": self.scenario_id,
        "scenario_type": self.scenario_type.value,
        "step_count": self.step_count,
        "board_width": self.width,
        "board_height": self.height,
        # Already respects the fog-of-war viewport, if any.
        "visible_cells": self.get_visible_board(),
        "discovered_signals": copy.deepcopy(self._discovered_signals),
        "memory_events": copy.deepcopy(self._memory_events),
        "game_over": self.game_over,
        "won": self.won,
        "flags_remaining": self.total_flags - self.flags_placed,
        "cells_revealed": self.cells_revealed,
        "hazard_hits": self.hazard_hits,
        "keys_collected": self.keys_collected,
        "max_steps": self.max_steps,
    }
    return BoardState(**snapshot)
def get_hidden_board(self) -> list[list[dict]]:
    """Full hidden board — for reward computation only, never sent to agent.

    Returns:
        A deep copy, so callers cannot alter the engine's ground truth.
    """
    ground_truth = copy.deepcopy(self._hidden)
    return ground_truth
def get_action_log(self) -> list[dict]:
    """Return a deep copy of the chronological list of recorded actions."""
    history = copy.deepcopy(self._action_log)
    return history
# ─── Internal Helpers ───────────────────────────────────────────
def _in_bounds(self, row: int, col: int) -> bool:
return 0 <= row < self.height and 0 <= col < self.width
def _get_viewport_bounds(self) -> dict:
if self._viewport_center is None or self._viewport_radius is None:
return {
"r_min": 0,
"r_max": self.height - 1,
"c_min": 0,
"c_max": self.width - 1,
}
vr, vc = self._viewport_center
rad = self._viewport_radius
return {
"r_min": max(0, vr - rad),
"r_max": min(self.height - 1, vr + rad),
"c_min": max(0, vc - rad),
"c_max": min(self.width - 1, vc + rad),
}
def _check_flag_win(self) -> None:
    """End the game as won if the flags match the hazards exactly.

    Only applies to the flag-all-hazards condition. The game is won when
    every hazard is flagged and no flag sits on a non-hazard cell;
    otherwise nothing changes.
    """
    if self.win_condition != WinCondition.FLAG_ALL_HAZARDS:
        return

    for r in range(self.height):
        for c in range(self.width):
            is_hazard = self._hidden[r][c]["type"] == CellType.HAZARD.value
            is_flagged = self._visible[r][c]["state"] == CellState.FLAGGED.value
            # A mismatch is either an unflagged hazard or a wrong flag;
            # both rule out a win.
            if is_hazard != is_flagged:
                return

    self.game_over = True
    self.won = True
def _judge_flag_solution(self, flagged: list[list[int]]) -> dict:
    """Grade a hazard submission and set ``self.won`` accordingly.

    The claimed hazards are the submitted positions together with every
    cell currently flagged on the board. The answer is correct only if
    that set equals the set of real hazards.

    Args:
        flagged: ``[row, col]`` pairs claimed to be hazards.

    Returns:
        A dictionary with ``correct``, the hit/miss/wrong counts and
        ``precision`` / ``recall`` rounded to three decimals. Precision is
        0.0 when nothing was claimed; recall is 1.0 when the board has no
        hazards.
    """
    coords = [(r, c) for r in range(self.height) for c in range(self.width)]

    hazards: set[tuple[int, int]] = {
        (r, c)
        for r, c in coords
        if self._hidden[r][c]["type"] == CellType.HAZARD.value
    }

    claimed: set[tuple[int, int]] = {(pos[0], pos[1]) for pos in flagged}
    # Flags placed during play count as part of the submission.
    claimed.update(
        (r, c)
        for r, c in coords
        if self._visible[r][c]["state"] == CellState.FLAGGED.value
    )

    hits = claimed & hazards
    overlooked = hazards - claimed
    mistaken = claimed - hazards

    precision = len(hits) / len(claimed) if claimed else 0.0
    recall = len(hits) / len(hazards) if hazards else 1.0

    self.won = not overlooked and not mistaken
    return {
        "correct": self.won,
        "hazards_found": len(hits),
        "hazards_total": len(hazards),
        "missed": len(overlooked),
        "wrong_flags": len(mistaken),
        "precision": round(precision, 3),
        "recall": round(recall, 3),
    }
def _judge_safe_solution(self, safe_positions: list[list[int]]) -> dict:
    """Grade a safe-cell submission and set ``self.won`` accordingly.

    The answer is correct only if the submitted positions are exactly the
    cells that are not hazards.

    Args:
        safe_positions: ``[row, col]`` pairs claimed to be safe.

    Returns:
        A dictionary with ``correct``, the found/false/missed counts and
        ``precision`` / ``recall`` rounded to three decimals. Precision is
        0.0 when nothing was submitted; recall is 1.0 when the board has
        no safe cells.
    """
    truly_safe: set[tuple[int, int]] = {
        (r, c)
        for r in range(self.height)
        for c in range(self.width)
        if self._hidden[r][c]["type"] != CellType.HAZARD.value
    }
    claimed: set[tuple[int, int]] = {(pos[0], pos[1]) for pos in safe_positions}

    hits = claimed & truly_safe
    unsafe_claims = claimed - truly_safe
    overlooked = truly_safe - claimed

    precision = len(hits) / len(claimed) if claimed else 0.0
    recall = len(hits) / len(truly_safe) if truly_safe else 1.0

    self.won = not unsafe_claims and not overlooked
    return {
        "correct": self.won,
        "safe_found": len(hits),
        "safe_total": len(truly_safe),
        "false_safe": len(unsafe_claims),
        "missed_safe": len(overlooked),
        "precision": round(precision, 3),
        "recall": round(recall, 3),
    }