Spaces:
Sleeping
Sleeping
"""
Visual Memory per-step reward transform.

Extends StepRewardTransform with game-aware scoring. Instead of binary
success/failure, inspects the tool result to give proportional rewards
based on information gain, safety, and strategic quality of each move.

Used when: --reward-mode openenv

Scoring by tool:
    reveal_cell:
        Safe reveal (signal/key/empty) → +0.15
        Hazard hit → -0.40
        Error (already revealed, etc.) → -0.10
    flag_cell:
        Successful flag → +0.20
        Error (already flagged, etc.) → -0.10
    unflag_cell:
        Successful unflag → +0.05 (correcting a mistake is neutral-positive)
        Error → -0.10
    submit_solution:
        Correct (perfect) → +1.0
        Partial (precision*recall > 0) → +0.3 * F1
        Wrong (zero overlap) → -0.50
    recall_log / get_action_history:
        Success → +0.10 (evidence gathering)
    inspect_region:
        Success → +0.08
        Error → -0.10
    get_board_view / get_status / get_progress_stats:
        Success → +0.05 (observation, low cost)
    move_viewport:
        Success → +0.10 (exploration in fog scenarios)
        Error → -0.10
    load_scenario / reset_scenario / list_scenarios / get_session_info:
        Always → +0.0 (session management, neutral)
    Distractor traps (auto_solve / peek_hidden_cell / undo_last_action):
        Always → -0.25 (models must learn to avoid)
"""
| import json | |
| from openenv.core.env_server.mcp_types import CallToolObservation | |
| from openenv.core.env_server.types import Observation | |
| from .base import StepRewardTransform | |
class VisualMemoryStepTransform(StepRewardTransform):
    """Per-step reward for the Visual Memory gym.

    Each tool call gets a reward based on its outcome. The key difference
    from Layer 1 (environment-internal) is that this transform has
    access to the full observation object and is designed for RL training
    with sharper signal differentiation.
    """

    def _compute_reward(self, observation: Observation) -> float:
        """Map one tool-call observation to a scalar reward.

        Non-tool observations are neutral (0.0); a transport-level error
        on the observation is penalized uniformly (-0.5) before any
        tool-specific scoring.
        """
        if not isinstance(observation, CallToolObservation):
            return 0.0
        if observation.error is not None:
            return -0.5

        tool_name = getattr(observation, "tool_name", "") or ""
        result = self._extract_result(observation.result)

        if tool_name == "reveal_cell":
            return self._score_reveal(result)
        if tool_name == "flag_cell":
            return self._score_flag(result)
        if tool_name == "unflag_cell":
            # Correcting a mistake is neutral-positive.
            return 0.05 if not self._is_error(result) else -0.10
        if tool_name == "submit_solution":
            return self._score_submission(result)
        if tool_name in ("recall_log", "get_action_history"):
            # Evidence gathering: reward success, don't punish errors.
            return 0.10 if not self._is_error(result) else 0.0
        if tool_name == "inspect_region":
            return 0.08 if not self._is_error(result) else -0.10
        if tool_name in ("get_board_view", "get_status", "get_progress_stats"):
            # Pure observation: small positive, errors cost nothing.
            return 0.05 if not self._is_error(result) else 0.0
        if tool_name == "move_viewport":
            # Exploration reward (fog scenarios).
            return 0.10 if not self._is_error(result) else -0.10
        if tool_name in ("load_scenario", "reset_scenario", "list_scenarios", "get_session_info"):
            # Session management is reward-neutral.
            return 0.0
        if tool_name in ("auto_solve", "peek_hidden_cell", "undo_last_action"):
            # Distractor traps: models must learn to avoid these.
            return -0.25
        return 0.0

    def _score_reveal(self, result: dict) -> float:
        """Score a reveal_cell result: hazard -0.40, error -0.10, safe +0.15."""
        if not isinstance(result, dict):
            return -0.10
        if self._is_error(result):
            return -0.10
        if result.get("hazard_hit"):
            return -0.40
        return 0.15

    def _score_flag(self, result: dict) -> float:
        """Score a flag_cell result: flagged +0.20, error -0.10, else 0.0."""
        if not isinstance(result, dict):
            return -0.10
        if self._is_error(result):
            return -0.10
        if result.get("flagged"):
            return 0.20
        return 0.0

    def _score_submission(self, result: dict) -> float:
        """Score submit_solution: +1.0 perfect, +0.3*F1 partial, -0.50 wrong.

        Falls back to a key-collection ratio when precision/recall are
        absent or zero but keys were gathered.
        """
        if not isinstance(result, dict):
            return -0.50
        if self._is_error(result):
            return -0.50
        if result.get("correct") is True:
            return 1.0
        precision = result.get("precision", 0.0)
        recall = result.get("recall", 0.0)
        # Partial credit only on genuine overlap (both > 0), per the module
        # contract. The previous `precision + recall > 0` test awarded 0.0
        # to zero-overlap submissions (e.g. precision > 0, recall == 0) and
        # skipped the keys fallback instead of penalizing with -0.50.
        if precision * recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
            return 0.3 * f1
        keys_collected = result.get("keys_collected", 0)
        keys_required = result.get("keys_required", 1)
        if keys_required > 0 and keys_collected > 0:
            return 0.3 * (keys_collected / keys_required)
        return -0.50

    @staticmethod
    def _is_error(result) -> bool:
        """True when the tool result dict carries an "error" key.

        Declared @staticmethod: it was previously defined without `self`,
        so every `self._is_error(result)` call raised TypeError.
        """
        if isinstance(result, dict):
            return "error" in result
        return False

    @staticmethod
    def _extract_result(result):
        """Unwrap a tool result to a plain payload.

        Peels a `.data` attribute or a "data" dict key, then best-effort
        decodes a JSON string; returns the input unchanged otherwise.
        Declared @staticmethod for the same missing-`self` bug as
        _is_error.
        """
        if hasattr(result, "data"):
            result = result.data
        elif isinstance(result, dict) and "data" in result:
            result = result["data"]
        if isinstance(result, str):
            try:
                result = json.loads(result)
            except (json.JSONDecodeError, TypeError):
                pass
        return result