Spaces:
Sleeping
Sleeping
| """Hidden-state game engine for Visual Memory Gym. | |
| Manages in-memory board state, hidden cell contents, move validation, | |
| and win/loss conditions across four task families: | |
| 1. hidden_grid — deduce hazard locations from signal clues | |
| 2. pattern_memory — recall briefly-shown cell contents | |
| 3. distractor_search — identify targets among visually similar decoys | |
| 4. fog_of_war — plan under limited viewport radius | |
| """ | |
| from __future__ import annotations | |
| import copy | |
| from enum import Enum | |
| from typing import Any | |
| import numpy as np | |
| from pydantic import BaseModel, Field | |
class CellType(str, Enum):
    """Kinds of content a cell on the hidden board can hold.

    The member values are the strings stored under the ``"type"`` key of
    each hidden-board cell.
    """

    # Nothing of interest and no adjacent-hazard information.
    EMPTY = "empty"
    # Revealing too many of these ends the game.
    HAZARD = "hazard"
    # Carries a clue about neighbouring hazards.
    SIGNAL = "signal"
    # Collectible; counted towards the collect-keys win condition.
    KEY = "key"
    # Placed alongside keys by the board generator.
    DECOY = "decoy"
    # Revealing it wins a reach-goal game.
    GOAL = "goal"
class CellState(str, Enum):
    """Visibility state of a cell as seen by the player.

    The member values are the strings stored under the ``"state"`` key of
    each visible-board cell.
    """

    # Not yet uncovered.
    HIDDEN = "hidden"
    # Uncovered; the cell's content is attached.
    REVEALED = "revealed"
    # Marked by the player as a suspected hazard.
    FLAGGED = "flagged"
    # Was shown during a pattern-memory flash and has since been hidden again.
    FADED = "faded"
class ScenarioType(str, Enum):
    """Task family a scenario belongs to (the scenario's ``"type"`` entry)."""

    # Deduce hazard locations from signal clues.
    HIDDEN_GRID = "hidden_grid"
    # Recall cell contents that were shown briefly at the start.
    PATTERN_MEMORY = "pattern_memory"
    # Identify targets among visually similar decoys.
    DISTRACTOR_SEARCH = "distractor_search"
    # Plan with only a limited viewport of the board visible.
    FOG_OF_WAR = "fog_of_war"
class SignalMode(str, Enum):
    """How a signal cell encodes information about adjacent hazards."""

    # Exact number of adjacent hazards.
    COUNT = "count"
    # Compass directions of every adjacent hazard.
    DIRECTIONAL = "directional"
    # A {"min", "max"} interval around the adjacent-hazard count.
    RANGE = "range"
    # A random subset of the hazard directions plus the total count.
    PARTIAL = "partial"
class WinCondition(str, Enum):
    """Objective that decides whether a finished game counts as won."""

    # Every hazard flagged and no other cell flagged.
    FLAG_ALL_HAZARDS = "flag_all_hazards"
    # Every key cell on the board revealed.
    COLLECT_KEYS = "collect_keys"
    # Submit exactly the set of non-hazard cells.
    IDENTIFY_SAFE = "identify_safe_cells"
    # Reveal the goal cell.
    REACH_GOAL = "reach_goal"
class BoardState(BaseModel):
    """Serializable snapshot of the game state (visible portion only)."""

    # --- Identification ---
    session_id: str = ""
    scenario_id: str = ""
    scenario_type: str = "hidden_grid"

    # --- Progress and board dimensions ---
    step_count: int = 0
    board_width: int = 0
    board_height: int = 0

    # --- What the player can currently see or has learned ---
    # Row-major grid of {"state", "content"} dicts.
    visible_cells: list[list[dict]] = Field(default_factory=list)
    discovered_signals: list[dict] = Field(default_factory=list)
    memory_events: list[dict] = Field(default_factory=list)

    # --- Outcome ---
    game_over: bool = False
    won: bool = False

    # --- Counters ---
    flags_remaining: int = 0
    cells_revealed: int = 0
    hazard_hits: int = 0
    keys_collected: int = 0
    max_steps: int = 50
# (row delta, column delta) of the eight cells surrounding a cell, listed
# row by row from the top-left neighbour to the bottom-right one.
NEIGHBOR_OFFSETS = [
    (d_row, d_col)
    for d_row in (-1, 0, 1)
    for d_col in (-1, 0, 1)
    if (d_row, d_col) != (0, 0)
]

# Compass label for each neighbour offset: row -1 is north, column -1 is west.
DIRECTION_NAMES = dict(
    zip(NEIGHBOR_OFFSETS, ("NW", "N", "NE", "W", "E", "SW", "S", "SE"))
)
class GameEngine:
    """In-memory game engine for the Visual Memory gym.

    Deterministic given a seed. All state lives in Python memory.

    Two parallel grids are kept: ``_hidden`` holds the true cell contents
    and ``_visible`` holds what the player has uncovered so far. Every
    action method returns a plain ``dict``; a rejected action returns a dict
    with an ``"error"`` key and does not consume a step.
    """

    _STEP_LIMIT_MESSAGE = "Max steps exceeded. Game over."

    def __init__(self, scenario: dict, seed: int | None = None):
        """Build a new game from a scenario description.

        Args:
            scenario: Scenario configuration. ``scenario_id``,
                ``board_width`` and ``board_height`` are required; every
                other key falls back to a default. When ``layout`` is
                present the board is loaded from it, otherwise it is
                generated randomly.
            seed: RNG seed. When ``None`` the scenario's ``seed`` entry
                (default 42) is used.

        Raises:
            ValueError: If the layout is smaller than the declared board, a
                flash cell lies outside the board, or more special cells
                are requested than the board can hold.
        """
        self.scenario_id: str = scenario["scenario_id"]
        self.scenario_type = ScenarioType(scenario.get("type", "hidden_grid"))
        self.width: int = scenario["board_width"]
        self.height: int = scenario["board_height"]
        self.max_steps: int = scenario.get("max_steps", 50)
        self.max_hazard_reveals: int = scenario.get("max_hazard_reveals", 3)
        self.signal_mode = SignalMode(scenario.get("signal_mode", "count"))
        self.win_condition = WinCondition(
            scenario.get("win_condition", {}).get("type", "flag_all_hazards")
        )

        resolved_seed = seed if seed is not None else scenario.get("seed", 42)
        self._rng = np.random.default_rng(resolved_seed)

        self.step_count: int = 0
        self.hazard_hits: int = 0
        self.keys_collected: int = 0
        self.cells_revealed: int = 0
        self.game_over: bool = False
        self.won: bool = False

        if "layout" in scenario:
            self._hidden = self._load_explicit_layout(scenario["layout"])
        else:
            self._hidden = self._generate_board(scenario)

        self._visible: list[list[dict]] = [
            [
                {"state": CellState.HIDDEN.value, "content": None}
                for _ in range(self.width)
            ]
            for _ in range(self.height)
        ]

        # By default the player gets three spare flags beyond the hazard count.
        total_hazards = self._count_hidden(CellType.HAZARD.value)
        self.total_flags: int = scenario.get("flags_count", total_hazards + 3)
        self.flags_placed: int = 0
        self.total_keys: int = self._count_hidden(CellType.KEY.value)

        self._discovered_signals: list[dict] = []
        self._memory_events: list[dict] = []
        self._action_log: list[dict] = []

        self._viewport_center: list[int] | None = scenario.get("start_position")
        self._viewport_radius: int | None = scenario.get("viewport_radius")

        self._flash_cells: list[list[int]] = scenario.get("flash_cells", [])
        self._flash_until_step: int = scenario.get("flash_until_step", 0)
        # One-shot latch: the flash is hidden exactly once per game.
        self._flash_faded: bool = False
        if self.scenario_type == ScenarioType.PATTERN_MEMORY and self._flash_cells:
            self._show_flash_cells()

    def _count_hidden(self, cell_type: str) -> int:
        """Count hidden cells of ``cell_type`` inside the declared board."""
        return sum(
            1
            for r in range(self.height)
            for c in range(self.width)
            if self._hidden[r][c]["type"] == cell_type
        )

    def _show_flash_cells(self) -> None:
        """Expose the configured flash cells and log a memory event for each.

        Raises:
            ValueError: If a flash cell lies outside the board. Without the
                check a negative index would silently wrap to the opposite
                edge of the board.
        """
        for rc in self._flash_cells:
            r, c = rc[0], rc[1]
            if not self._in_bounds(r, c):
                raise ValueError(
                    f"Flash cell ({r},{c}) is outside the "
                    f"{self.width}x{self.height} board"
                )
            cell = self._hidden[r][c]
            self._visible[r][c] = {
                "state": CellState.REVEALED.value,
                "content": copy.deepcopy(cell),
            }
            self._memory_events.append({
                "step": 0,
                "event": "flash_shown",
                "row": r,
                "col": c,
                "content": copy.deepcopy(cell),
            })

    # --- Board Generation ---------------------------------------------

    def _load_explicit_layout(self, layout: list[list[dict]]) -> list[list[dict]]:
        """Normalise an explicit layout into hidden-board cells.

        Each cell gets ``type`` (default empty), ``value`` and
        ``properties`` keys. No signals are computed for explicit layouts.

        Raises:
            ValueError: If the layout has fewer rows or columns than the
                declared board size. Extra rows or columns are tolerated.
        """
        too_few_rows = len(layout) < self.height
        if too_few_rows or any(
            len(row_data) < self.width for row_data in layout[: self.height]
        ):
            raise ValueError(
                f"Layout is smaller than the declared "
                f"{self.width}x{self.height} board"
            )
        return [
            [
                {
                    "type": cell.get("type", CellType.EMPTY.value),
                    "value": cell.get("value"),
                    "properties": cell.get("properties", {}),
                }
                for cell in row_data
            ]
            for row_data in layout
        ]

    def _generate_board(self, scenario: dict) -> list[list[dict]]:
        """Randomly place hazards, keys, decoys and the goal, then add signals.

        Special cells take consecutive entries of one seeded permutation of
        all cell indices, so no two of them share a cell.

        Raises:
            ValueError: If more special cells are requested than the board
                has cells.
        """
        hazard_count = scenario.get("hazard_count", 10)
        key_count = scenario.get("key_count", 0)
        decoy_count = scenario.get("decoy_count", 0)
        goal_count = 1 if self.win_condition == WinCondition.REACH_GOAL else 0

        total_cells = self.width * self.height
        total_special = hazard_count + key_count + decoy_count + goal_count
        if total_special > total_cells:
            raise ValueError(
                f"Cannot place {total_special} special cells on a "
                f"{self.width}x{self.height} board ({total_cells} cells)"
            )

        positions = self._rng.permutation(total_cells)
        board: list[list[dict]] = [
            [
                {"type": CellType.EMPTY.value, "value": None, "properties": {}}
                for _ in range(self.width)
            ]
            for _ in range(self.height)
        ]

        # (type, value) of every special cell, in placement order.
        specials: list[tuple[str, Any]] = []
        specials += [(CellType.HAZARD.value, None) for _ in range(hazard_count)]
        specials += [(CellType.KEY.value, f"key_{i}") for i in range(key_count)]
        specials += [(CellType.DECOY.value, f"decoy_{i}") for i in range(decoy_count)]
        specials += [(CellType.GOAL.value, None) for _ in range(goal_count)]

        for idx, (cell_type, value) in enumerate(specials):
            r, c = divmod(int(positions[idx]), self.width)
            board[r][c] = {"type": cell_type, "value": value, "properties": {}}

        self._compute_signals(board)
        return board

    def _compute_signals(self, board: list[list[dict]]) -> None:
        """Turn empty cells next to hazards into signal cells, in place.

        Cells are visited in row-major order; the order matters because the
        range and partial modes draw from the seeded RNG per cell.
        """
        for r in range(self.height):
            for c in range(self.width):
                if board[r][c]["type"] != CellType.EMPTY.value:
                    continue
                signal = self._build_signal(board, r, c)
                if signal is not None:
                    board[r][c] = signal

    def _build_signal(self, board: list[list[dict]], r: int, c: int) -> dict | None:
        """Return the signal cell for ``(r, c)``, or ``None`` if no hazard is adjacent."""
        mode = self.signal_mode

        if mode == SignalMode.COUNT:
            count = self._count_adjacent_hazards(board, r, c)
            if count > 0:
                return {
                    "type": CellType.SIGNAL.value,
                    "value": count,
                    "properties": {"mode": "count"},
                }
            return None

        if mode == SignalMode.DIRECTIONAL:
            directions = self._get_hazard_directions(board, r, c)
            if directions:
                return {
                    "type": CellType.SIGNAL.value,
                    "value": directions,
                    "properties": {"mode": "directional"},
                }
            return None

        if mode == SignalMode.RANGE:
            count = self._count_adjacent_hazards(board, r, c)
            if count > 0:
                # Two independent 0/1 draws widen the interval downwards and
                # upwards; the lower draw must come first to keep seeds stable.
                noise = int(self._rng.integers(0, 2))
                low = max(0, count - noise)
                high = count + int(self._rng.integers(0, 2))
                return {
                    "type": CellType.SIGNAL.value,
                    "value": {"min": low, "max": high},
                    "properties": {"mode": "range"},
                }
            return None

        if mode == SignalMode.PARTIAL:
            directions = self._get_hazard_directions(board, r, c)
            if directions:
                # Show roughly half of the directions, but always at least one.
                shown = max(1, len(directions) // 2)
                indices = self._rng.choice(
                    len(directions), size=shown, replace=False
                )
                subset = [directions[i] for i in sorted(indices)]
                return {
                    "type": CellType.SIGNAL.value,
                    "value": subset,
                    "properties": {
                        "mode": "partial",
                        "total_hint": len(directions),
                    },
                }
            return None

        return None

    def _count_adjacent_hazards(self, board: list[list[dict]], r: int, c: int) -> int:
        """Count hazards among the in-bounds neighbours of ``(r, c)``."""
        return sum(
            1
            for dr, dc in NEIGHBOR_OFFSETS
            if self._in_bounds(r + dr, c + dc)
            and board[r + dr][c + dc]["type"] == CellType.HAZARD.value
        )

    def _get_hazard_directions(self, board: list[list[dict]], r: int, c: int) -> list[str]:
        """List the compass names of in-bounds neighbours of ``(r, c)`` that are hazards."""
        return [
            name
            for (dr, dc), name in DIRECTION_NAMES.items()
            if self._in_bounds(r + dr, c + dc)
            and board[r + dr][c + dc]["type"] == CellType.HAZARD.value
        ]

    # --- Pattern Memory Phase -----------------------------------------

    def _tick_pattern_memory(self) -> None:
        """Hide the flashed cells once the step counter reaches the threshold.

        Uses ``>=`` with a one-shot latch rather than an equality test: the
        counter is already 1 on the first tick, so a threshold of 0 (the
        default) would otherwise never match and the flash would stay
        visible forever. The latch keeps cells the player reveals later
        from being faded again.
        """
        if self.scenario_type != ScenarioType.PATTERN_MEMORY:
            return
        if self._flash_faded or self.step_count < self._flash_until_step:
            return
        self._flash_faded = True
        for rc in self._flash_cells:
            r, c = rc[0], rc[1]
            if self._visible[r][c]["state"] == CellState.REVEALED.value:
                self._visible[r][c] = {"state": CellState.FADED.value, "content": None}
                self._memory_events.append({
                    "step": self.step_count,
                    "event": "flash_faded",
                    "row": r,
                    "col": c,
                })

    # --- Core Actions -------------------------------------------------

    def _apply_step_limit(self, result: dict[str, Any]) -> None:
        """End the game as a loss when the step budget is used up.

        Does nothing if the game has already ended. Otherwise marks the
        game over and reports it in ``result``.
        """
        if self.step_count >= self.max_steps and not self.game_over:
            self.game_over = True
            self.won = False
            result["game_over"] = True
            result["message"] = self._STEP_LIMIT_MESSAGE

    def reveal_cell(self, row: int, col: int) -> dict:
        """Uncover a cell and apply its effect.

        Returns:
            On success a dict with ``row``, ``col``, ``type``, ``value`` and
            ``properties`` plus optional ``hazard_hit``, ``key_collected``,
            ``game_over`` and ``message`` entries. The dict is independent
            of engine state. On rejection (game over, out of bounds,
            already revealed or flagged, outside the viewport) a dict with
            an ``error`` message.
        """
        if self.game_over:
            return {"error": "Game is already over.", "row": row, "col": col}
        if not self._in_bounds(row, col):
            return {"error": f"({row},{col}) is out of bounds.", "row": row, "col": col}

        vis = self._visible[row][col]
        if vis["state"] in (CellState.REVEALED.value, CellState.FLAGGED.value):
            return {
                "error": f"Cell ({row},{col}) is already {vis['state']}.",
                "row": row,
                "col": col,
            }

        if self._viewport_radius is not None and self._viewport_center is not None:
            vr, vc = self._viewport_center
            if abs(row - vr) > self._viewport_radius or abs(col - vc) > self._viewport_radius:
                return {
                    "error": f"({row},{col}) is outside your current viewport.",
                    "row": row,
                    "col": col,
                }

        self.step_count += 1
        self._tick_pattern_memory()

        hidden = self._hidden[row][col]
        cell_type = hidden["type"]
        self._visible[row][col] = {
            "state": CellState.REVEALED.value,
            "content": copy.deepcopy(hidden),
        }
        self.cells_revealed += 1

        # Deep-copy value/properties so callers cannot mutate the hidden board
        # through the returned dict.
        result: dict[str, Any] = {
            "row": row,
            "col": col,
            "type": cell_type,
            "value": copy.deepcopy(hidden.get("value")),
            "properties": copy.deepcopy(hidden.get("properties", {})),
        }

        if cell_type == CellType.SIGNAL.value:
            # Store a snapshot: ``result`` may still gain game_over/message keys
            # below and is handed to the caller.
            self._discovered_signals.append(copy.deepcopy(result))

        if cell_type == CellType.HAZARD.value:
            self.hazard_hits += 1
            result["hazard_hit"] = True
            if self.hazard_hits >= self.max_hazard_reveals:
                self.game_over = True
                self.won = False
                result["game_over"] = True
                result["message"] = "Too many hazards revealed. Game over."

        if cell_type == CellType.KEY.value:
            self.keys_collected += 1
            result["key_collected"] = True
            if (
                self.win_condition == WinCondition.COLLECT_KEYS
                and self.keys_collected >= self.total_keys
            ):
                self.game_over = True
                self.won = True
                result["game_over"] = True
                result["message"] = "All keys collected. You win!"

        if cell_type == CellType.GOAL.value and self.win_condition == WinCondition.REACH_GOAL:
            self.game_over = True
            self.won = True
            result["game_over"] = True
            result["message"] = "Goal reached. You win!"

        self._apply_step_limit(result)

        self._action_log.append({
            "action": "reveal",
            "row": row,
            "col": col,
            "step": self.step_count,
            "result_type": cell_type,
        })
        return result

    def flag_cell(self, row: int, col: int) -> dict:
        """Mark a cell as a suspected hazard.

        Returns:
            On success a dict with ``row``, ``col``, ``flagged`` and
            ``flags_remaining`` plus ``game_over``/``message`` if the flag
            ended the game. On rejection (game over, out of bounds, cell
            revealed or already flagged, no flags left) a dict with an
            ``error`` message.
        """
        if self.game_over:
            return {"error": "Game is already over.", "row": row, "col": col}
        if not self._in_bounds(row, col):
            return {"error": f"({row},{col}) is out of bounds.", "row": row, "col": col}

        state = self._visible[row][col]["state"]
        if state == CellState.REVEALED.value:
            return {"error": f"Cell ({row},{col}) is already revealed; cannot flag.", "row": row, "col": col}
        if state == CellState.FLAGGED.value:
            return {"error": f"Cell ({row},{col}) is already flagged.", "row": row, "col": col}
        if self.flags_placed >= self.total_flags:
            return {"error": "No flags remaining.", "row": row, "col": col}

        self.step_count += 1
        self._tick_pattern_memory()

        self._visible[row][col] = {"state": CellState.FLAGGED.value, "content": None}
        self.flags_placed += 1
        self._action_log.append({"action": "flag", "row": row, "col": col, "step": self.step_count})

        self._check_flag_win()

        result: dict[str, Any] = {
            "row": row,
            "col": col,
            "flagged": True,
            "flags_remaining": self.total_flags - self.flags_placed,
        }
        if self.game_over and self.won:
            result["game_over"] = True
            result["message"] = "All hazards correctly flagged. You win!"

        self._apply_step_limit(result)
        return result

    def unflag_cell(self, row: int, col: int) -> dict:
        """Remove a flag, returning the cell to the hidden state.

        Returns:
            On success a dict with ``row``, ``col``, ``unflagged`` and
            ``flags_remaining`` plus ``game_over``/``message`` if the step
            limit was reached. On rejection a dict with an ``error``
            message.
        """
        if self.game_over:
            return {"error": "Game is already over.", "row": row, "col": col}
        if not self._in_bounds(row, col):
            return {"error": f"({row},{col}) is out of bounds.", "row": row, "col": col}
        if self._visible[row][col]["state"] != CellState.FLAGGED.value:
            return {"error": f"Cell ({row},{col}) is not flagged.", "row": row, "col": col}

        self.step_count += 1
        self._tick_pattern_memory()

        self._visible[row][col] = {"state": CellState.HIDDEN.value, "content": None}
        self.flags_placed -= 1
        self._action_log.append({"action": "unflag", "row": row, "col": col, "step": self.step_count})

        result: dict[str, Any] = {
            "row": row,
            "col": col,
            "unflagged": True,
            "flags_remaining": self.total_flags - self.flags_placed,
        }
        self._apply_step_limit(result)
        return result

    def move_viewport(self, row: int, col: int) -> dict:
        """Re-centre the viewport on ``(row, col)`` (fog_of_war scenarios only).

        Returns:
            On success a dict with ``viewport_center``, ``viewport_radius``
            and ``visible_area`` plus ``game_over``/``message`` when the
            move used up the last step. On rejection a dict with an
            ``error`` message.
        """
        if self.scenario_type != ScenarioType.FOG_OF_WAR:
            return {"error": "move_viewport is only available in fog_of_war scenarios."}
        if self.game_over:
            return {"error": "Game is already over."}
        if not self._in_bounds(row, col):
            return {"error": f"({row},{col}) is out of bounds."}

        self.step_count += 1
        self._tick_pattern_memory()

        self._viewport_center = [row, col]
        self._action_log.append({
            "action": "move_viewport",
            "row": row,
            "col": col,
            "step": self.step_count,
        })

        result: dict[str, Any] = {
            "viewport_center": [row, col],
            "viewport_radius": self._viewport_radius,
            "visible_area": self._get_viewport_bounds(),
        }
        # Report a step-limit loss to the caller, as the other actions do.
        self._apply_step_limit(result)
        return result

    def submit_solution(
        self,
        flagged_positions: list[list[int]] | None = None,
        safe_positions: list[list[int]] | None = None,
    ) -> dict:
        """Submit a final answer; this always ends the game.

        Consumes one step and judges the submission against the scenario's
        win condition. ``flagged_positions`` is used for flag-all-hazards
        and ``safe_positions`` for identify-safe; the other win conditions
        ignore both.

        Returns:
            A judgement dict whose ``correct`` entry mirrors ``self.won``,
            or a dict with an ``error`` message if the game was already
            over.
        """
        if self.game_over:
            return {"error": "Game is already over."}

        self.step_count += 1
        self.game_over = True

        if self.win_condition == WinCondition.FLAG_ALL_HAZARDS:
            return self._judge_flag_solution(flagged_positions or [])
        if self.win_condition == WinCondition.IDENTIFY_SAFE:
            return self._judge_safe_solution(safe_positions or [])
        if self.win_condition == WinCondition.COLLECT_KEYS:
            success = self.keys_collected >= self.total_keys
            self.won = success
            return {
                "correct": success,
                "keys_collected": self.keys_collected,
                "keys_required": self.total_keys,
            }
        if self.win_condition == WinCondition.REACH_GOAL:
            # Revealing the goal ends the game on its own, so reaching this
            # point means the goal was never found.
            self.won = False
            return {"correct": False, "message": "Goal was not reached before submission."}

        return {"error": "Unknown win condition."}

    # --- State Queries ------------------------------------------------

    def get_visible_board(self) -> list[list[dict]]:
        """Return a copy of the player-visible grid.

        When a viewport is configured, cells outside it are reported as
        ``{"state": "fog", "content": None}``.
        """
        if self._viewport_radius is None or self._viewport_center is None:
            return copy.deepcopy(self._visible)

        vr, vc = self._viewport_center
        rad = self._viewport_radius
        fog_board: list[list[dict]] = [
            [{"state": "fog", "content": None} for _ in range(self.width)]
            for _ in range(self.height)
        ]
        for r in range(max(0, vr - rad), min(self.height, vr + rad + 1)):
            for c in range(max(0, vc - rad), min(self.width, vc + rad + 1)):
                fog_board[r][c] = copy.deepcopy(self._visible[r][c])
        return fog_board

    def get_status(self) -> dict:
        """Return a flat summary of counters, limits and outcome flags."""
        return {
            "scenario_id": self.scenario_id,
            "scenario_type": self.scenario_type.value,
            "step_count": self.step_count,
            "max_steps": self.max_steps,
            "board_size": f"{self.width}x{self.height}",
            "cells_revealed": self.cells_revealed,
            "hazard_hits": self.hazard_hits,
            "max_hazard_reveals": self.max_hazard_reveals,
            "keys_collected": self.keys_collected,
            "total_keys": self.total_keys,
            "flags_placed": self.flags_placed,
            "flags_remaining": self.total_flags - self.flags_placed,
            "game_over": self.game_over,
            "won": self.won,
            "win_condition": self.win_condition.value,
        }

    def get_board_state(self, session_id: str = "") -> BoardState:
        """Return a ``BoardState`` snapshot of everything the player may see."""
        return BoardState(
            session_id=session_id,
            scenario_id=self.scenario_id,
            scenario_type=self.scenario_type.value,
            step_count=self.step_count,
            board_width=self.width,
            board_height=self.height,
            visible_cells=self.get_visible_board(),
            discovered_signals=copy.deepcopy(self._discovered_signals),
            memory_events=copy.deepcopy(self._memory_events),
            game_over=self.game_over,
            won=self.won,
            flags_remaining=self.total_flags - self.flags_placed,
            cells_revealed=self.cells_revealed,
            hazard_hits=self.hazard_hits,
            keys_collected=self.keys_collected,
            max_steps=self.max_steps,
        )

    def get_hidden_board(self) -> list[list[dict]]:
        """Full hidden board — for reward computation only, never sent to agent."""
        return copy.deepcopy(self._hidden)

    def get_action_log(self) -> list[dict]:
        """Return a copy of the log of successfully applied actions."""
        return copy.deepcopy(self._action_log)

    # --- Internal Helpers ---------------------------------------------

    def _in_bounds(self, row: int, col: int) -> bool:
        """Tell whether ``(row, col)`` lies on the board."""
        return 0 <= row < self.height and 0 <= col < self.width

    def _get_viewport_bounds(self) -> dict:
        """Return the inclusive row/column bounds of the visible area.

        Without a configured viewport the bounds cover the whole board.
        """
        if self._viewport_center is None or self._viewport_radius is None:
            return {
                "r_min": 0,
                "r_max": self.height - 1,
                "c_min": 0,
                "c_max": self.width - 1,
            }
        vr, vc = self._viewport_center
        rad = self._viewport_radius
        return {
            "r_min": max(0, vr - rad),
            "r_max": min(self.height - 1, vr + rad),
            "c_min": max(0, vc - rad),
            "c_max": min(self.width - 1, vc + rad),
        }

    def _check_flag_win(self) -> None:
        """End the game as a win once the flags match the hazards exactly.

        Only applies to the flag-all-hazards win condition. A single
        mismatch (unflagged hazard or flagged non-hazard) means no win yet.
        """
        if self.win_condition != WinCondition.FLAG_ALL_HAZARDS:
            return
        for r in range(self.height):
            for c in range(self.width):
                is_hazard = self._hidden[r][c]["type"] == CellType.HAZARD.value
                is_flagged = self._visible[r][c]["state"] == CellState.FLAGGED.value
                if is_hazard != is_flagged:
                    return
        self.game_over = True
        self.won = True

    def _judge_flag_solution(self, flagged: list[list[int]]) -> dict:
        """Score a hazard submission and set ``self.won``.

        The judged set is the union of ``flagged`` and every cell currently
        flagged on the board. The game is won only if that set equals the
        hazard set exactly.
        """
        actual_hazards: set[tuple[int, int]] = {
            (r, c)
            for r in range(self.height)
            for c in range(self.width)
            if self._hidden[r][c]["type"] == CellType.HAZARD.value
        }
        submitted: set[tuple[int, int]] = {(p[0], p[1]) for p in flagged}
        submitted |= {
            (r, c)
            for r in range(self.height)
            for c in range(self.width)
            if self._visible[r][c]["state"] == CellState.FLAGGED.value
        }

        correct = submitted & actual_hazards
        missed = actual_hazards - submitted
        wrong = submitted - actual_hazards

        precision = len(correct) / len(submitted) if submitted else 0.0
        recall = len(correct) / len(actual_hazards) if actual_hazards else 1.0

        self.won = not missed and not wrong
        return {
            "correct": self.won,
            "hazards_found": len(correct),
            "hazards_total": len(actual_hazards),
            "missed": len(missed),
            "wrong_flags": len(wrong),
            "precision": round(precision, 3),
            "recall": round(recall, 3),
        }

    def _judge_safe_solution(self, safe_positions: list[list[int]]) -> dict:
        """Score a safe-cell submission and set ``self.won``.

        Every non-hazard cell counts as safe. The game is won only if the
        submitted set equals the safe set exactly.
        """
        actual_safe: set[tuple[int, int]] = {
            (r, c)
            for r in range(self.height)
            for c in range(self.width)
            if self._hidden[r][c]["type"] != CellType.HAZARD.value
        }
        submitted: set[tuple[int, int]] = {(p[0], p[1]) for p in safe_positions}

        correct = submitted & actual_safe
        false_safe = submitted - actual_safe
        missed_safe = actual_safe - submitted

        precision = len(correct) / len(submitted) if submitted else 0.0
        recall = len(correct) / len(actual_safe) if actual_safe else 1.0

        self.won = not false_safe and not missed_safe
        return {
            "correct": self.won,
            "safe_found": len(correct),
            "safe_total": len(actual_safe),
            "false_safe": len(false_safe),
            "missed_safe": len(missed_safe),
            "precision": round(precision, 3),
            "recall": round(recall, 3),
        }