# visual_memory/rewards/transforms.py
# Uploaded by kdemon1011 using huggingface_hub (revision 15503f9, verified).
"""
Visual Memory per-step reward transform.
Extends StepRewardTransform with game-aware scoring. Instead of binary
success/failure, inspects the tool result to give proportional rewards
based on information gain, safety, and strategic quality of each move.
Used when: --reward-mode openenv
Scoring by tool (any CallToolObservation whose own `error` field is set
scores -0.50 before per-tool scoring is applied):
    reveal_cell:
        Safe reveal (signal/key/empty) → +0.15
        Hazard hit → -0.40
        Error (already revealed, etc.) → -0.10
    flag_cell:
        Successful flag → +0.20
        Result not marked as flagged → +0.0
        Error (already flagged, etc.) → -0.10
    unflag_cell:
        Successful unflag → +0.05 (correcting a mistake is neutral-positive)
        Error → -0.10
    submit_solution:
        Correct (perfect) → +1.0
        Partial (precision + recall > 0) → +0.3 * F1
        Otherwise, if some keys were collected → +0.3 scaled by keys_collected / keys_required
        Wrong (zero overlap, no keys collected) → -0.50
    recall_log / get_action_history:
        Success → +0.10 (evidence gathering)
        Error → +0.0
    inspect_region:
        Success → +0.08
        Error → -0.10
    get_board_view / get_status / get_progress_stats:
        Success → +0.05 (observation, low cost)
        Error → +0.0
    move_viewport:
        Success → +0.10 (exploration in fog scenarios)
        Error → -0.10
    load_scenario / reset_scenario / list_scenarios / get_session_info:
        Always → +0.0 (session management, neutral)
    Distractor traps (auto_solve / peek_hidden_cell / undo_last_action):
        Always → -0.25 (models must learn to avoid)
"""
import json
from openenv.core.env_server.mcp_types import CallToolObservation
from openenv.core.env_server.types import Observation
from .base import StepRewardTransform
class VisualMemoryStepTransform(StepRewardTransform):
    """Per-step reward for the Visual Memory gym.

    Each tool call gets a reward based on its outcome. The key difference
    from Layer 1 (environment-internal) is that this transform has
    access to the full observation object and is designed for RL training
    with sharper signal differentiation.

    Scoring order:
      * observations that are not ``CallToolObservation`` score 0.0;
      * observations whose ``error`` field is not None score -0.5;
      * otherwise the decoded tool result is scored by tool name, and
        unrecognized tool names score 0.0.
    """

    # Tool-name groups that share a flat success/error reward rule.
    _EVIDENCE_TOOLS = frozenset({"recall_log", "get_action_history"})
    _OBSERVATION_TOOLS = frozenset(
        {"get_board_view", "get_status", "get_progress_stats"}
    )
    # Session-management tools: always neutral.
    _SESSION_TOOLS = frozenset(
        {"load_scenario", "reset_scenario", "list_scenarios", "get_session_info"}
    )
    # Distractor tools the policy should learn to avoid: always penalized.
    _DISTRACTOR_TOOLS = frozenset(
        {"auto_solve", "peek_hidden_cell", "undo_last_action"}
    )

    def _compute_reward(self, observation: Observation) -> float:
        """Return the shaped reward for one step's observation.

        Args:
            observation: Observation emitted by the environment for a step.

        Returns:
            The per-step reward (see the class docstring for the order in
            which cases are checked).
        """
        if not isinstance(observation, CallToolObservation):
            return 0.0
        if observation.error is not None:
            return -0.5

        tool_name = getattr(observation, "tool_name", "") or ""
        result = self._extract_result(observation.result)

        # Tools whose reward depends on specific fields of the payload.
        if tool_name == "reveal_cell":
            return self._score_reveal(result)
        if tool_name == "flag_cell":
            return self._score_flag(result)
        if tool_name == "submit_solution":
            return self._score_submission(result)

        # Tools with a flat success/error reward.
        is_error = self._is_error(result)
        if tool_name == "unflag_cell":
            return -0.10 if is_error else 0.05
        if tool_name in self._EVIDENCE_TOOLS:
            return 0.0 if is_error else 0.10
        if tool_name == "inspect_region":
            return -0.10 if is_error else 0.08
        if tool_name in self._OBSERVATION_TOOLS:
            return 0.0 if is_error else 0.05
        if tool_name == "move_viewport":
            return -0.10 if is_error else 0.10
        if tool_name in self._SESSION_TOOLS:
            return 0.0
        if tool_name in self._DISTRACTOR_TOOLS:
            return -0.25
        return 0.0

    def _score_reveal(self, result: object) -> float:
        """Score a ``reveal_cell`` result.

        Returns -0.10 for a non-dict or error payload, -0.40 when the payload
        has a truthy ``hazard_hit``, and +0.15 otherwise.
        """
        if not isinstance(result, dict) or self._is_error(result):
            return -0.10
        if result.get("hazard_hit"):
            return -0.40
        return 0.15

    def _score_flag(self, result: object) -> float:
        """Score a ``flag_cell`` result.

        Returns -0.10 for a non-dict or error payload, +0.20 when the payload
        has a truthy ``flagged``, and 0.0 otherwise.
        """
        if not isinstance(result, dict) or self._is_error(result):
            return -0.10
        if result.get("flagged"):
            return 0.20
        return 0.0

    def _score_submission(self, result: object) -> float:
        """Score a ``submit_solution`` result.

        Returns:
            +1.0 when ``correct`` is exactly ``True``; otherwise 0.3 * F1 when
            ``precision + recall > 0``; otherwise 0.3 times the (capped at 1)
            ratio ``keys_collected / keys_required`` when both are positive;
            otherwise -0.50. Non-dict or error payloads score -0.50.
        """
        if not isinstance(result, dict) or self._is_error(result):
            return -0.50
        if result.get("correct") is True:
            return 1.0

        # Metrics may be missing, null, or non-numeric; coerce instead of
        # letting a TypeError escape from reward computation.
        precision = self._as_float(result.get("precision"), 0.0)
        recall = self._as_float(result.get("recall"), 0.0)
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
            return 0.3 * f1

        keys_collected = self._as_float(result.get("keys_collected"), 0.0)
        keys_required = self._as_float(result.get("keys_required"), 1.0)
        if keys_required > 0 and keys_collected > 0:
            # Cap at 1 so partial credit never exceeds the +0.3 ceiling.
            return 0.3 * min(1.0, keys_collected / keys_required)
        return -0.50

    @staticmethod
    def _as_float(value: object, default: float) -> float:
        """Coerce a payload metric to float, or return ``default``.

        ``None`` and values rejected by ``float()`` fall back to ``default``.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _is_error(result: object) -> bool:
        """Return True when ``result`` is a dict with a non-None ``"error"``.

        An explicit ``"error": None`` slot is treated as success, so payloads
        that always include the key are not penalized on every call.
        Non-dict results are never classified as errors here.
        """
        return isinstance(result, dict) and result.get("error") is not None

    @staticmethod
    def _extract_result(result: object) -> object:
        """Unwrap and decode a tool-call result payload.

        Takes ``result.data`` when the object has a ``data`` attribute, or
        ``result["data"]`` when it is a dict with that key. A resulting string
        is parsed as JSON when possible; otherwise it is returned unchanged.
        """
        if hasattr(result, "data"):
            result = result.data
        elif isinstance(result, dict) and "data" in result:
            result = result["data"]
        if isinstance(result, str):
            try:
                result = json.loads(result)
            except (json.JSONDecodeError, TypeError):
                # Not JSON: keep the raw string as the result.
                pass
        return result