"""Hidden-state game engine for Visual Memory Gym. Manages in-memory board state, hidden cell contents, move validation, and win/loss conditions across four task families: 1. hidden_grid — deduce hazard locations from signal clues 2. pattern_memory — recall briefly-shown cell contents 3. distractor_search — identify targets among visually similar decoys 4. fog_of_war — plan under limited viewport radius """ from __future__ import annotations import copy from enum import Enum from typing import Any import numpy as np from pydantic import BaseModel, Field class CellType(str, Enum): EMPTY = "empty" HAZARD = "hazard" SIGNAL = "signal" KEY = "key" DECOY = "decoy" GOAL = "goal" class CellState(str, Enum): HIDDEN = "hidden" REVEALED = "revealed" FLAGGED = "flagged" FADED = "faded" class ScenarioType(str, Enum): HIDDEN_GRID = "hidden_grid" PATTERN_MEMORY = "pattern_memory" DISTRACTOR_SEARCH = "distractor_search" FOG_OF_WAR = "fog_of_war" class SignalMode(str, Enum): COUNT = "count" DIRECTIONAL = "directional" RANGE = "range" PARTIAL = "partial" class WinCondition(str, Enum): FLAG_ALL_HAZARDS = "flag_all_hazards" COLLECT_KEYS = "collect_keys" IDENTIFY_SAFE = "identify_safe_cells" REACH_GOAL = "reach_goal" class BoardState(BaseModel): """Serializable snapshot of the game state (visible portion only).""" session_id: str = "" scenario_id: str = "" scenario_type: str = "hidden_grid" step_count: int = 0 board_width: int = 0 board_height: int = 0 visible_cells: list[list[dict]] = Field(default_factory=list) discovered_signals: list[dict] = Field(default_factory=list) memory_events: list[dict] = Field(default_factory=list) game_over: bool = False won: bool = False flags_remaining: int = 0 cells_revealed: int = 0 hazard_hits: int = 0 keys_collected: int = 0 max_steps: int = 50 NEIGHBOR_OFFSETS = [ (-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1), ] DIRECTION_NAMES = { (-1, -1): "NW", (-1, 0): "N", (-1, 1): "NE", (0, -1): "W", (0, 1): "E", (1, -1): "SW", (1, 0): "S", (1, 1): "SE", } class GameEngine: """In-memory game engine for the Visual Memory gym. Deterministic given a seed. All state lives in Python memory. """ def __init__(self, scenario: dict, seed: int | None = None): self.scenario_id: str = scenario["scenario_id"] self.scenario_type = ScenarioType(scenario.get("type", "hidden_grid")) self.width: int = scenario["board_width"] self.height: int = scenario["board_height"] self.max_steps: int = scenario.get("max_steps", 50) self.max_hazard_reveals: int = scenario.get("max_hazard_reveals", 3) self.signal_mode = SignalMode(scenario.get("signal_mode", "count")) self.win_condition = WinCondition( scenario.get("win_condition", {}).get("type", "flag_all_hazards") ) resolved_seed = seed if seed is not None else scenario.get("seed", 42) self._rng = np.random.default_rng(resolved_seed) self.step_count: int = 0 self.hazard_hits: int = 0 self.keys_collected: int = 0 self.cells_revealed: int = 0 self.game_over: bool = False self.won: bool = False if "layout" in scenario: self._hidden = self._load_explicit_layout(scenario["layout"]) else: self._hidden = self._generate_board(scenario) self._visible: list[list[dict]] = [ [{"state": CellState.HIDDEN.value, "content": None} for _ in range(self.width)] for _ in range(self.height) ] total_hazards = sum( 1 for r in range(self.height) for c in range(self.width) if self._hidden[r][c]["type"] == CellType.HAZARD.value ) self.total_flags: int = scenario.get("flags_count", total_hazards + 3) self.flags_placed: int = 0 self.total_keys: int = sum( 1 for r in range(self.height) for c in range(self.width) if self._hidden[r][c]["type"] == CellType.KEY.value ) self._discovered_signals: list[dict] = [] self._memory_events: list[dict] = [] self._action_log: list[dict] = [] self._viewport_center: list[int] | None = scenario.get("start_position") self._viewport_radius: int | None = scenario.get("viewport_radius") self._flash_cells: list[list[int]] = scenario.get("flash_cells", []) self._flash_until_step: int = scenario.get("flash_until_step", 0) if self.scenario_type == ScenarioType.PATTERN_MEMORY and self._flash_cells: for rc in self._flash_cells: r, c = rc[0], rc[1] cell = self._hidden[r][c] self._visible[r][c] = { "state": CellState.REVEALED.value, "content": copy.deepcopy(cell), } self._memory_events.append({ "step": 0, "event": "flash_shown", "row": r, "col": c, "content": copy.deepcopy(cell), }) # ─── Board Generation ─────────────────────────────────────────── def _load_explicit_layout(self, layout: list[list[dict]]) -> list[list[dict]]: board: list[list[dict]] = [] for row_data in layout: row: list[dict] = [] for cell in row_data: row.append({ "type": cell.get("type", CellType.EMPTY.value), "value": cell.get("value"), "properties": cell.get("properties", {}), }) board.append(row) return board def _generate_board(self, scenario: dict) -> list[list[dict]]: hazard_count = scenario.get("hazard_count", 10) key_count = scenario.get("key_count", 0) decoy_count = scenario.get("decoy_count", 0) goal_count = 1 if self.win_condition == WinCondition.REACH_GOAL else 0 total_cells = self.width * self.height total_special = hazard_count + key_count + decoy_count + goal_count if total_special > total_cells: raise ValueError( f"Cannot place {total_special} special cells on a " f"{self.width}x{self.height} board ({total_cells} cells)" ) positions = self._rng.permutation(total_cells) board: list[list[dict]] = [ [{"type": CellType.EMPTY.value, "value": None, "properties": {}} for _ in range(self.width)] for _ in range(self.height) ] idx = 0 for _ in range(hazard_count): r, c = divmod(int(positions[idx]), self.width) board[r][c] = {"type": CellType.HAZARD.value, "value": None, "properties": {}} idx += 1 for i in range(key_count): r, c = divmod(int(positions[idx]), self.width) board[r][c] = {"type": CellType.KEY.value, "value": f"key_{i}", "properties": {}} idx += 1 for i in range(decoy_count): r, c = divmod(int(positions[idx]), self.width) board[r][c] = {"type": CellType.DECOY.value, "value": f"decoy_{i}", "properties": {}} idx += 1 if goal_count: r, c = divmod(int(positions[idx]), self.width) board[r][c] = {"type": CellType.GOAL.value, "value": None, "properties": {}} idx += 1 self._compute_signals(board) return board def _compute_signals(self, board: list[list[dict]]) -> None: for r in range(self.height): for c in range(self.width): if board[r][c]["type"] != CellType.EMPTY.value: continue if self.signal_mode == SignalMode.COUNT: count = self._count_adjacent_hazards(board, r, c) if count > 0: board[r][c] = { "type": CellType.SIGNAL.value, "value": count, "properties": {"mode": "count"}, } elif self.signal_mode == SignalMode.DIRECTIONAL: directions = self._get_hazard_directions(board, r, c) if directions: board[r][c] = { "type": CellType.SIGNAL.value, "value": directions, "properties": {"mode": "directional"}, } elif self.signal_mode == SignalMode.RANGE: count = self._count_adjacent_hazards(board, r, c) if count > 0: noise = int(self._rng.integers(0, 2)) low = max(0, count - noise) high = count + int(self._rng.integers(0, 2)) board[r][c] = { "type": CellType.SIGNAL.value, "value": {"min": low, "max": high}, "properties": {"mode": "range"}, } elif self.signal_mode == SignalMode.PARTIAL: directions = self._get_hazard_directions(board, r, c) if directions: shown = max(1, len(directions) // 2) indices = self._rng.choice( len(directions), size=shown, replace=False ) subset = [directions[i] for i in sorted(indices)] board[r][c] = { "type": CellType.SIGNAL.value, "value": subset, "properties": { "mode": "partial", "total_hint": len(directions), }, } def _count_adjacent_hazards(self, board: list[list[dict]], r: int, c: int) -> int: count = 0 for dr, dc in NEIGHBOR_OFFSETS: nr, nc = r + dr, c + dc if 0 <= nr < self.height and 0 <= nc < self.width: if board[nr][nc]["type"] == CellType.HAZARD.value: count += 1 return count def _get_hazard_directions(self, board: list[list[dict]], r: int, c: int) -> list[str]: dirs: list[str] = [] for (dr, dc), name in DIRECTION_NAMES.items(): nr, nc = r + dr, c + dc if 0 <= nr < self.height and 0 <= nc < self.width: if board[nr][nc]["type"] == CellType.HAZARD.value: dirs.append(name) return dirs # ─── Pattern Memory Phase ─────────────────────────────────────── def _tick_pattern_memory(self) -> None: if self.scenario_type != ScenarioType.PATTERN_MEMORY: return if self.step_count != self._flash_until_step: return for rc in self._flash_cells: r, c = rc[0], rc[1] if self._visible[r][c]["state"] == CellState.REVEALED.value: self._visible[r][c] = {"state": CellState.FADED.value, "content": None} self._memory_events.append({ "step": self.step_count, "event": "flash_faded", "row": r, "col": c, }) # ─── Core Actions ─────────────────────────────────────────────── def reveal_cell(self, row: int, col: int) -> dict: if self.game_over: return {"error": "Game is already over.", "row": row, "col": col} if not self._in_bounds(row, col): return {"error": f"({row},{col}) is out of bounds.", "row": row, "col": col} vis = self._visible[row][col] if vis["state"] in (CellState.REVEALED.value, CellState.FLAGGED.value): return { "error": f"Cell ({row},{col}) is already {vis['state']}.", "row": row, "col": col, } if self._viewport_radius is not None and self._viewport_center is not None: vr, vc = self._viewport_center if abs(row - vr) > self._viewport_radius or abs(col - vc) > self._viewport_radius: return { "error": f"({row},{col}) is outside your current viewport.", "row": row, "col": col, } self.step_count += 1 self._tick_pattern_memory() hidden = self._hidden[row][col] cell_type = hidden["type"] self._visible[row][col] = { "state": CellState.REVEALED.value, "content": copy.deepcopy(hidden), } self.cells_revealed += 1 result: dict[str, Any] = { "row": row, "col": col, "type": cell_type, "value": hidden.get("value"), "properties": hidden.get("properties", {}), } if cell_type == CellType.SIGNAL.value: self._discovered_signals.append(result) if cell_type == CellType.HAZARD.value: self.hazard_hits += 1 result["hazard_hit"] = True if self.hazard_hits >= self.max_hazard_reveals: self.game_over = True self.won = False result["game_over"] = True result["message"] = "Too many hazards revealed. Game over." if cell_type == CellType.KEY.value: self.keys_collected += 1 result["key_collected"] = True if ( self.win_condition == WinCondition.COLLECT_KEYS and self.keys_collected >= self.total_keys ): self.game_over = True self.won = True result["game_over"] = True result["message"] = "All keys collected. You win!" if cell_type == CellType.GOAL.value and self.win_condition == WinCondition.REACH_GOAL: self.game_over = True self.won = True result["game_over"] = True result["message"] = "Goal reached. You win!" if self.step_count >= self.max_steps and not self.game_over: self.game_over = True self.won = False result["game_over"] = True result["message"] = "Max steps exceeded. Game over." self._action_log.append({ "action": "reveal", "row": row, "col": col, "step": self.step_count, "result_type": cell_type, }) return result def flag_cell(self, row: int, col: int) -> dict: if self.game_over: return {"error": "Game is already over.", "row": row, "col": col} if not self._in_bounds(row, col): return {"error": f"({row},{col}) is out of bounds.", "row": row, "col": col} vis = self._visible[row][col] if vis["state"] == CellState.REVEALED.value: return {"error": f"Cell ({row},{col}) is already revealed; cannot flag.", "row": row, "col": col} if vis["state"] == CellState.FLAGGED.value: return {"error": f"Cell ({row},{col}) is already flagged.", "row": row, "col": col} if self.flags_placed >= self.total_flags: return {"error": "No flags remaining.", "row": row, "col": col} self.step_count += 1 self._tick_pattern_memory() self._visible[row][col] = {"state": CellState.FLAGGED.value, "content": None} self.flags_placed += 1 self._action_log.append({"action": "flag", "row": row, "col": col, "step": self.step_count}) self._check_flag_win() result: dict[str, Any] = { "row": row, "col": col, "flagged": True, "flags_remaining": self.total_flags - self.flags_placed, } if self.game_over and self.won: result["game_over"] = True result["message"] = "All hazards correctly flagged. You win!" if self.step_count >= self.max_steps and not self.game_over: self.game_over = True self.won = False result["game_over"] = True result["message"] = "Max steps exceeded. Game over." return result def unflag_cell(self, row: int, col: int) -> dict: if self.game_over: return {"error": "Game is already over.", "row": row, "col": col} if not self._in_bounds(row, col): return {"error": f"({row},{col}) is out of bounds.", "row": row, "col": col} if self._visible[row][col]["state"] != CellState.FLAGGED.value: return {"error": f"Cell ({row},{col}) is not flagged.", "row": row, "col": col} self.step_count += 1 self._tick_pattern_memory() self._visible[row][col] = {"state": CellState.HIDDEN.value, "content": None} self.flags_placed -= 1 self._action_log.append({"action": "unflag", "row": row, "col": col, "step": self.step_count}) result: dict[str, Any] = { "row": row, "col": col, "unflagged": True, "flags_remaining": self.total_flags - self.flags_placed, } if self.step_count >= self.max_steps and not self.game_over: self.game_over = True self.won = False result["game_over"] = True result["message"] = "Max steps exceeded. Game over." return result def move_viewport(self, row: int, col: int) -> dict: if self.scenario_type != ScenarioType.FOG_OF_WAR: return {"error": "move_viewport is only available in fog_of_war scenarios."} if self.game_over: return {"error": "Game is already over."} if not self._in_bounds(row, col): return {"error": f"({row},{col}) is out of bounds."} self.step_count += 1 self._tick_pattern_memory() self._viewport_center = [row, col] self._action_log.append({ "action": "move_viewport", "row": row, "col": col, "step": self.step_count, }) if self.step_count >= self.max_steps and not self.game_over: self.game_over = True self.won = False return { "viewport_center": [row, col], "viewport_radius": self._viewport_radius, "visible_area": self._get_viewport_bounds(), } def submit_solution( self, flagged_positions: list[list[int]] | None = None, safe_positions: list[list[int]] | None = None, ) -> dict: if self.game_over: return {"error": "Game is already over."} self.step_count += 1 self.game_over = True if self.win_condition == WinCondition.FLAG_ALL_HAZARDS: return self._judge_flag_solution(flagged_positions or []) elif self.win_condition == WinCondition.IDENTIFY_SAFE: return self._judge_safe_solution(safe_positions or []) elif self.win_condition == WinCondition.COLLECT_KEYS: success = self.keys_collected >= self.total_keys self.won = success return { "correct": success, "keys_collected": self.keys_collected, "keys_required": self.total_keys, } elif self.win_condition == WinCondition.REACH_GOAL: self.won = False return {"correct": False, "message": "Goal was not reached before submission."} return {"error": "Unknown win condition."} # ─── State Queries ────────────────────────────────────────────── def get_visible_board(self) -> list[list[dict]]: if self._viewport_radius is None or self._viewport_center is None: return copy.deepcopy(self._visible) vr, vc = self._viewport_center rad = self._viewport_radius fog_board: list[list[dict]] = [ [{"state": "fog", "content": None} for _ in range(self.width)] for _ in range(self.height) ] for r in range(max(0, vr - rad), min(self.height, vr + rad + 1)): for c in range(max(0, vc - rad), min(self.width, vc + rad + 1)): fog_board[r][c] = copy.deepcopy(self._visible[r][c]) return fog_board def get_status(self) -> dict: return { "scenario_id": self.scenario_id, "scenario_type": self.scenario_type.value, "step_count": self.step_count, "max_steps": self.max_steps, "board_size": f"{self.width}x{self.height}", "cells_revealed": self.cells_revealed, "hazard_hits": self.hazard_hits, "max_hazard_reveals": self.max_hazard_reveals, "keys_collected": self.keys_collected, "total_keys": self.total_keys, "flags_placed": self.flags_placed, "flags_remaining": self.total_flags - self.flags_placed, "game_over": self.game_over, "won": self.won, "win_condition": self.win_condition.value, } def get_board_state(self, session_id: str = "") -> BoardState: return BoardState( session_id=session_id, scenario_id=self.scenario_id, scenario_type=self.scenario_type.value, step_count=self.step_count, board_width=self.width, board_height=self.height, visible_cells=self.get_visible_board(), discovered_signals=copy.deepcopy(self._discovered_signals), memory_events=copy.deepcopy(self._memory_events), game_over=self.game_over, won=self.won, flags_remaining=self.total_flags - self.flags_placed, cells_revealed=self.cells_revealed, hazard_hits=self.hazard_hits, keys_collected=self.keys_collected, max_steps=self.max_steps, ) def get_hidden_board(self) -> list[list[dict]]: """Full hidden board — for reward computation only, never sent to agent.""" return copy.deepcopy(self._hidden) def get_action_log(self) -> list[dict]: return copy.deepcopy(self._action_log) # ─── Internal Helpers ─────────────────────────────────────────── def _in_bounds(self, row: int, col: int) -> bool: return 0 <= row < self.height and 0 <= col < self.width def _get_viewport_bounds(self) -> dict: if self._viewport_center is None or self._viewport_radius is None: return { "r_min": 0, "r_max": self.height - 1, "c_min": 0, "c_max": self.width - 1, } vr, vc = self._viewport_center rad = self._viewport_radius return { "r_min": max(0, vr - rad), "r_max": min(self.height - 1, vr + rad), "c_min": max(0, vc - rad), "c_max": min(self.width - 1, vc + rad), } def _check_flag_win(self) -> None: if self.win_condition != WinCondition.FLAG_ALL_HAZARDS: return for r in range(self.height): for c in range(self.width): is_hazard = self._hidden[r][c]["type"] == CellType.HAZARD.value is_flagged = self._visible[r][c]["state"] == CellState.FLAGGED.value if is_hazard and not is_flagged: return wrong_flags = sum( 1 for r in range(self.height) for c in range(self.width) if self._visible[r][c]["state"] == CellState.FLAGGED.value and self._hidden[r][c]["type"] != CellType.HAZARD.value ) if wrong_flags == 0: self.game_over = True self.won = True def _judge_flag_solution(self, flagged: list[list[int]]) -> dict: actual_hazards: set[tuple[int, int]] = set() for r in range(self.height): for c in range(self.width): if self._hidden[r][c]["type"] == CellType.HAZARD.value: actual_hazards.add((r, c)) submitted: set[tuple[int, int]] = {(p[0], p[1]) for p in flagged} for r in range(self.height): for c in range(self.width): if self._visible[r][c]["state"] == CellState.FLAGGED.value: submitted.add((r, c)) correct = submitted & actual_hazards missed = actual_hazards - submitted wrong = submitted - actual_hazards precision = len(correct) / len(submitted) if submitted else 0.0 recall = len(correct) / len(actual_hazards) if actual_hazards else 1.0 self.won = len(missed) == 0 and len(wrong) == 0 return { "correct": self.won, "hazards_found": len(correct), "hazards_total": len(actual_hazards), "missed": len(missed), "wrong_flags": len(wrong), "precision": round(precision, 3), "recall": round(recall, 3), } def _judge_safe_solution(self, safe_positions: list[list[int]]) -> dict: actual_safe: set[tuple[int, int]] = set() for r in range(self.height): for c in range(self.width): if self._hidden[r][c]["type"] != CellType.HAZARD.value: actual_safe.add((r, c)) submitted: set[tuple[int, int]] = {(p[0], p[1]) for p in safe_positions} correct = submitted & actual_safe false_safe = submitted - actual_safe missed_safe = actual_safe - submitted precision = len(correct) / len(submitted) if submitted else 0.0 recall = len(correct) / len(actual_safe) if actual_safe else 1.0 self.won = len(false_safe) == 0 and len(missed_safe) == 0 return { "correct": self.won, "safe_found": len(correct), "safe_total": len(actual_safe), "false_safe": len(false_safe), "missed_safe": len(missed_safe), "precision": round(precision, 3), "recall": round(recall, 3), }