"""
Visual Memory per-step reward transform.
Extends StepRewardTransform with game-aware scoring. Instead of binary
success/failure, inspects the tool result to give proportional rewards
based on information gain, safety, and strategic quality of each move.
Used when: --reward-mode openenv
Scoring by tool:

reveal_cell:
    Safe reveal (signal/key/empty) → +0.15
    Hazard hit → -0.40
    Error (already revealed, etc.) → -0.10
flag_cell:
    Successful flag → +0.20
    Error (already flagged, etc.) → -0.10
unflag_cell:
    Successful unflag → +0.05 (correcting a mistake is neutral-positive)
    Error → -0.10
submit_solution:
    Correct (perfect) → +1.0
    Partial (precision*recall > 0) → +0.3 * F1
    Wrong (zero overlap) → -0.50
recall_log / get_action_history:
    Success → +0.10 (evidence gathering)
inspect_region:
    Success → +0.08
    Error → -0.10
get_board_view / get_status / get_progress_stats:
    Success → +0.05 (observation, low cost)
move_viewport:
    Success → +0.10 (exploration in fog scenarios)
    Error → -0.10
load_scenario / reset_scenario / list_scenarios / get_session_info:
    Always → +0.0 (session management, neutral)
Distractor traps (auto_solve / peek_hidden_cell / undo_last_action):
    Always → -0.25 (models must learn to avoid)
"""
import json
from openenv.core.env_server.mcp_types import CallToolObservation
from openenv.core.env_server.types import Observation
from .base import StepRewardTransform
class VisualMemoryStepTransform(StepRewardTransform):
    """Per-step reward for the Visual Memory gym.

    Each tool call gets a reward based on its outcome. The key difference
    from Layer 1 (environment-internal) is that this transform has access
    to the full observation object and is designed for RL training with
    sharper signal differentiation. See the module docstring for the
    complete per-tool reward table.
    """

    def _compute_reward(self, observation: Observation) -> float:
        """Map a single tool-call observation to a scalar reward.

        Non-tool observations are neutral (0.0). A transport-level error on
        the observation itself is penalised at -0.5, separately from the
        in-band ``{"error": ...}`` payloads handled per tool below.
        """
        if not isinstance(observation, CallToolObservation):
            return 0.0
        if observation.error is not None:
            # Protocol/transport failure, regardless of which tool was called.
            return -0.5

        tool_name = getattr(observation, "tool_name", "") or ""
        result = self._extract_result(observation.result)

        if tool_name == "reveal_cell":
            return self._score_reveal(result)
        if tool_name == "flag_cell":
            return self._score_flag(result)
        if tool_name == "unflag_cell":
            # Correcting a mistake is neutral-positive.
            return 0.05 if not self._is_error(result) else -0.10
        if tool_name == "submit_solution":
            return self._score_submission(result)
        if tool_name in ("recall_log", "get_action_history"):
            # Evidence gathering; errors here are not penalised.
            return 0.10 if not self._is_error(result) else 0.0
        if tool_name == "inspect_region":
            return 0.08 if not self._is_error(result) else -0.10
        if tool_name in ("get_board_view", "get_status", "get_progress_stats"):
            # Cheap observation tools; errors are not penalised.
            return 0.05 if not self._is_error(result) else 0.0
        if tool_name == "move_viewport":
            # Exploration reward for fog-of-war scenarios.
            return 0.10 if not self._is_error(result) else -0.10
        if tool_name in ("load_scenario", "reset_scenario", "list_scenarios", "get_session_info"):
            # Session management is reward-neutral.
            return 0.0
        if tool_name in ("auto_solve", "peek_hidden_cell", "undo_last_action"):
            # Distractor traps: models must learn to avoid these.
            return -0.25
        # Unknown tool: neutral.
        return 0.0

    def _score_reveal(self, result: dict) -> float:
        """Score a ``reveal_cell`` result: -0.40 on hazard, +0.15 on a safe reveal."""
        if not isinstance(result, dict):
            return -0.10
        if self._is_error(result):
            return -0.10
        if result.get("hazard_hit"):
            return -0.40
        return 0.15

    def _score_flag(self, result: dict) -> float:
        """Score a ``flag_cell`` result: +0.20 on a confirmed flag, -0.10 on error."""
        if not isinstance(result, dict):
            return -0.10
        if self._is_error(result):
            return -0.10
        if result.get("flagged"):
            return 0.20
        # Well-formed result without a flag confirmation: neutral.
        return 0.0

    def _score_submission(self, result: dict) -> float:
        """Score a ``submit_solution`` result.

        +1.0 for a perfect solution, ``0.3 * F1`` for a partial overlap,
        proportional key-collection credit as a fallback, and -0.50 for
        malformed, error, or zero-overlap submissions.
        """
        if not isinstance(result, dict):
            return -0.50
        if self._is_error(result):
            return -0.50
        if result.get("correct") is True:
            return 1.0
        # Treat a missing or null metric as 0.0 instead of raising.
        precision = result.get("precision") or 0.0
        recall = result.get("recall") or 0.0
        # BUG FIX: partial credit requires *both* metrics positive
        # (precision*recall > 0, per the module docstring). The previous
        # `precision + recall > 0` gate returned an F1 of 0.0 for
        # zero-overlap submissions instead of falling through to -0.50.
        if precision > 0 and recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
            return 0.3 * f1
        # Fallback partial credit for key-collection scenarios.
        keys_collected = result.get("keys_collected", 0)
        keys_required = result.get("keys_required", 1)
        if keys_required > 0 and keys_collected > 0:
            return 0.3 * (keys_collected / keys_required)
        return -0.50

    @staticmethod
    def _is_error(result) -> bool:
        """Return True when *result* is a dict carrying an ``error`` key."""
        if isinstance(result, dict):
            return "error" in result
        return False

    @staticmethod
    def _extract_result(result):
        """Unwrap a tool result into a plain dict where possible.

        Handles ``.data`` attribute wrappers, ``{"data": ...}`` envelopes,
        and JSON-encoded string payloads; anything else is returned as-is.
        """
        if hasattr(result, "data"):
            result = result.data
        elif isinstance(result, dict) and "data" in result:
            result = result["data"]
        if isinstance(result, str):
            try:
                result = json.loads(result)
            except (json.JSONDecodeError, TypeError):
                pass
        return result