import random
from typing import Any, Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from models import WhyDidItFailAction, WhyDidItFailObservation, WhyDidItFailState
from server.graders import grade
from server.scenarios import SCENARIOS


class WhyDidItFailEnvironment(Environment):
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    _ALL_ACTIONS = ["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"]

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.scenario: dict | None = None
        self.inspection_order: list[str] = []  # first-visit order; doubles as membership check
        self.max_steps: int = 0

    @property
    def state(self) -> WhyDidItFailState:
        return WhyDidItFailState(
            episode_id=self._state.episode_id,
            step_count=self._state.step_count,
            scenario_key=self.scenario.get("failure_mode") if self.scenario else None,
            difficulty=self.scenario.get("difficulty") if self.scenario else None,
            inspection_order=list(self.inspection_order),
            required_sources=list(self.scenario.get("required_sources", [])) if self.scenario else [],
            max_steps=self.max_steps,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> WhyDidItFailObservation:
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self.inspection_order = []

        scenario_key = kwargs.get("scenario_key")
        if scenario_key and scenario_key in SCENARIOS:
            self.scenario = SCENARIOS[scenario_key]
        else:
            # Use a local RNG so seeding never mutates process-global state;
            # this environment advertises concurrent-session support.
            rng = random.Random(seed)
            self.scenario = rng.choice(list(SCENARIOS.values()))

        required_sources = self.scenario.get("required_sources", ["logs"])
        self.max_steps = len(required_sources) * 3 + 2

        return WhyDidItFailObservation(
            task_description=(
                "A training run has failed. Diagnose the root cause.\n"
                f"Difficulty: {self.scenario['difficulty']}. "
                "Available actions: inspect_logs, inspect_config, inspect_gradients, submit_diagnosis."
            ),
            visible_data={"hint": "Start by inspecting the training logs."},
            available_actions=list(self._ALL_ACTIONS),
            steps_taken=0,
            reward=0.10,
            done=False,
            feedback="Investigation started.",
        )
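    # Worked example of the step budget above, for a hypothetical scenario with
    # required_sources = ["logs", "gradients"]:
    #   max_steps = len(["logs", "gradients"]) * 3 + 2 = 2 * 3 + 2 = 8
    # i.e. roughly three actions per required clue plus two spare steps before
    # the hard limit in step() fires.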
    def step(
        self,
        action: WhyDidItFailAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> WhyDidItFailObservation:
        if self.scenario is None:
            raise RuntimeError("Environment must be reset before calling step.")

        self._state.step_count += 1

        # Hard step limit: terminate immediately with the minimum reward (0.10).
        if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
            return WhyDidItFailObservation(
                task_description="Step limit reached. Episode terminated.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=0.10,
                done=True,
                feedback=(
                    f"Step limit ({self.max_steps}) reached without a diagnosis. "
                    f"Score: 0.10. Actual failure: '{self.scenario['correct_diagnosis']}'."
                ),
            )

        required: list[str] = self.scenario.get("required_sources", ["logs"])

        if action.action_type == "inspect_logs":
            step_reward = self._inspect_reward("logs", required)
            # Compute feedback before recording the visit; otherwise
            # _inspect_feedback would always report the source as already seen.
            feedback = self._inspect_feedback("logs", required, step_reward)
            if "logs" not in self.inspection_order:
                self.inspection_order.append("logs")
            return WhyDidItFailObservation(
                task_description="Continue your investigation.",
                visible_data={"training_logs": self.scenario["logs"]},
                available_actions=list(self._ALL_ACTIONS),
                steps_taken=self._state.step_count,
                reward=step_reward,
                done=False,
                feedback=feedback,
            )
        elif action.action_type == "inspect_config":
            step_reward = self._inspect_reward("config", required)
            feedback = self._inspect_feedback("config", required, step_reward)
            if "config" not in self.inspection_order:
                self.inspection_order.append("config")
            return WhyDidItFailObservation(
                task_description="Continue your investigation.",
                visible_data={"config": self.scenario["config"]},
                available_actions=list(self._ALL_ACTIONS),
                steps_taken=self._state.step_count,
                reward=step_reward,
                done=False,
                feedback=feedback,
            )
        elif action.action_type == "inspect_gradients":
            step_reward = self._inspect_reward("gradients", required)
            feedback = self._inspect_feedback("gradients", required, step_reward)
            if "gradients" not in self.inspection_order:
                self.inspection_order.append("gradients")
            return WhyDidItFailObservation(
                task_description="Continue your investigation.",
                visible_data={"gradient_norms": self.scenario["gradient_norms"]},
                available_actions=list(self._ALL_ACTIONS),
                steps_taken=self._state.step_count,
                reward=step_reward,
                done=False,
                feedback=feedback,
            )
        elif action.action_type == "submit_diagnosis":
            final_reward, feedback = self._grade(action)
            return WhyDidItFailObservation(
                task_description="Diagnosis submitted.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=final_reward,
                done=True,
                feedback=feedback,
            )
        else:
            return WhyDidItFailObservation(
                task_description="Continue your investigation.",
                visible_data={},
                available_actions=list(self._ALL_ACTIONS),
                steps_taken=self._state.step_count,
                reward=0.10,
                done=False,
                feedback=f"Unknown action '{action.action_type}'. Minimum reward.",
            )

    # Rewards decay as more required sources are discovered; the first clue is
    # worth the most. All values lie in [0.10, 0.90], so there are no negative rewards.
    _REQUIRED_STEP_REWARDS = [0.50, 0.30, 0.15]

    def _inspect_reward(self, source: str, required: list[str]) -> float:
        """Return the step reward for an inspect action.

        Required sources: progressive, 0.50 / 0.30 / 0.15 for the 1st / 2nd / 3rd discovery.
        Irrelevant sources: 0.10 (minimum; a mild penalty by contrast with required rewards).
        Re-inspection: 0.10 (minimum; a wasted step with no new information).
        All values lie in [0.10, 0.90].
        """
        if source in self.inspection_order:
            return 0.10  # redundant inspection: minimum reward
        if source in required:
            n_found = sum(1 for s in self.inspection_order if s in required)
            idx = min(n_found, len(self._REQUIRED_STEP_REWARDS) - 1)
            return self._REQUIRED_STEP_REWARDS[idx]
        return 0.10  # irrelevant source: minimum reward
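    # Worked reward trace for a hypothetical episode with
    # required = ["logs", "gradients"]:
    #   inspect_logs      -> 0.50  (first required source discovered)
    #   inspect_config    -> 0.10  (irrelevant for this failure mode)
    #   inspect_gradients -> 0.30  (second required source discovered)
    #   inspect_logs      -> 0.10  (re-inspection, no new information)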
    def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
        label = {
            "logs": "training logs",
            "config": "hyperparameter config",
            "gradients": "gradient statistics",
        }[source]
        if source in self.inspection_order:
            return f"You already examined the {label}. No new information gained."
        if source in required:
            remaining_sources = [s for s in required if s not in self.inspection_order and s != source]
            msg = f"You examined the {label}. Relevant clue found (+{reward:.2f})."
            if remaining_sources:
                next_source = f"inspect_{remaining_sources[0]}"
                msg += (
                    f" {len(remaining_sources)} required source(s) still unexamined. "
                    f"Next required action: {next_source}."
                )
            return msg
        return f"You examined the {label}. This source is not required for this failure mode."

    def _grade(self, action: WhyDidItFailAction) -> tuple[float, str]:
        """Delegate to the unified grade() function and return (reward, feedback)."""
        assert self.scenario is not None
        diagnosis = (action.diagnosis or "").strip().lower()
        suggested_fix = (action.suggested_fix or "").strip().lower() or None
        difficulty = self.scenario["difficulty"]
        reward = grade(
            diagnosis=diagnosis,
            suggested_fix=suggested_fix,
            scenario=self.scenario,
            steps_taken=self._state.step_count,
            inspection_order=self.inspection_order,
            difficulty=difficulty,
        )
        if reward >= 0.80:
            feedback = f"Excellent diagnosis! Score: {reward:.2f}."
        elif reward >= 0.50:
            feedback = f"Partially correct. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."
        else:
            feedback = f"Incorrect diagnosis. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."
        return reward, feedback
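

# Minimal usage sketch. Assumptions: WhyDidItFailAction accepts the
# action_type / diagnosis keyword fields that step() reads, and the observation
# exposes reward / feedback / done as attributes; adjust to the actual model
# definitions in models if they differ. The submitted diagnosis below is a
# hypothetical answer, not tied to any real scenario.
if __name__ == "__main__":
    env = WhyDidItFailEnvironment()
    obs = env.reset(seed=0)
    print(obs.feedback)

    # Inspect every source once, then submit a diagnosis.
    for action_type in ["inspect_logs", "inspect_config", "inspect_gradients"]:
        obs = env.step(WhyDidItFailAction(action_type=action_type))
        print(f"{action_type}: reward={obs.reward:.2f} | {obs.feedback}")

    obs = env.step(
        WhyDidItFailAction(
            action_type="submit_diagnosis",
            diagnosis="learning rate too high",  # hypothetical diagnosis
        )
    )
    print(f"done={obs.done} reward={obs.reward:.2f} | {obs.feedback}")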