Spaces:
Sleeping
Sleeping
File size: 9,820 Bytes
572e42a b37875f ff8ce5f 04666da e216a2f b37875f 572e42a b37875f e216a2f a818334 1288c52 b37875f 3613ecf ff8ce5f 3613ecf 572e42a a0518e7 a818334 e216a2f 1288c52 e216a2f 1288c52 87037e2 e216a2f d3b224f e216a2f b37875f 572e42a e216a2f b37875f 1288c52 d3b224f 1288c52 d3b224f 1288c52 d3b224f 1288c52 e216a2f b37875f 572e42a e216a2f a818334 e216a2f 572e42a e216a2f a818334 e216a2f 572e42a e216a2f a818334 e216a2f 572e42a e216a2f 572e42a e216a2f 572e42a e216a2f 572e42a e216a2f d3b224f e216a2f d3b224f e216a2f b37875f a818334 d3b224f a818334 e216a2f a818334 d3b224f a818334 d3b224f a818334 e216a2f a818334 d3b224f e216a2f a818334 e216a2f a818334 d29cfdb a818334 d29cfdb a818334 e216a2f a818334 e216a2f a818334 e216a2f 572e42a a818334 e216a2f b37875f e216a2f 572e42a e216a2f 572e42a e216a2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 | import random
from typing import Any, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import WhyDidItFailAction, WhyDidItFailObservation, WhyDidItFailState
from server.scenarios import SCENARIOS
from server.graders import grade
class WhyDidItFailEnvironment(Environment):
    """Diagnostic RL environment: a training run has failed and the agent must
    inspect evidence sources (logs / config / gradients), then submit a
    root-cause diagnosis. All step rewards are shaped into [0.10, 0.90].
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Rewards decay as more required sources are discovered — first clue is
    # worth most. All values are in [0.10, 0.90] — no negative rewards.
    _REQUIRED_STEP_REWARDS = [0.50, 0.30, 0.15]

    # action_type -> (canonical source name, scenario dict key, visible_data key).
    # The three inspect_* handlers differ only in these keys, so they share
    # one implementation (_inspect_observation) instead of three copies.
    _INSPECT_ACTIONS: dict[str, tuple[str, str, str]] = {
        "inspect_logs": ("logs", "logs", "training_logs"),
        "inspect_config": ("config", "config", "config"),
        "inspect_gradients": ("gradients", "gradient_norms", "gradient_norms"),
    }

    _ALL_ACTIONS: list[str] = [
        "inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis",
    ]

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.scenario: dict | None = None
        self.inspection_order: list[str] = []  # first-visit order; doubles as membership check
        self.max_steps: int = 0

    @property
    def state(self) -> WhyDidItFailState:
        """Snapshot of the current episode; safe to read before reset()."""
        return WhyDidItFailState(
            episode_id=self._state.episode_id,
            step_count=self._state.step_count,
            # NOTE(review): scenario_key is taken from the scenario's
            # "failure_mode" field — presumably this matches the SCENARIOS
            # dict keys used in reset(); confirm against server.scenarios.
            scenario_key=self.scenario.get("failure_mode") if self.scenario else None,
            difficulty=self.scenario.get("difficulty") if self.scenario else None,
            inspection_order=list(self.inspection_order),
            required_sources=list(self.scenario.get("required_sources", [])) if self.scenario else [],
            max_steps=self.max_steps,
        )

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
        """Start a new episode.

        kwargs may carry ``scenario_key`` to pin a specific scenario;
        otherwise one is sampled (reproducibly when ``seed`` is given).
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self.inspection_order = []
        scenario_key = kwargs.get("scenario_key")
        if scenario_key and scenario_key in SCENARIOS:
            self.scenario = SCENARIOS[scenario_key]
        else:
            # Use an instance-local RNG so seeding never mutates the
            # process-global random state (this class declares
            # SUPPORTS_CONCURRENT_SESSIONS). random.Random(seed) yields the
            # same choice as random.seed(seed) + random.choice for a given seed.
            rng = random.Random(seed) if seed is not None else random
            self.scenario = rng.choice(list(SCENARIOS.values()))
        required_sources = self.scenario.get("required_sources", ["logs"])
        # Step budget: 3 steps per required source plus 2 steps of slack.
        self.max_steps = len(required_sources) * 3 + 2
        return WhyDidItFailObservation(
            task_description=(
                "A training run has failed. Diagnose the root cause.\n"
                f"Difficulty: {self.scenario['difficulty']}. "
                "Available actions: inspect_logs, inspect_config, inspect_gradients, submit_diagnosis."
            ),
            visible_data={"hint": "Start by inspecting the training logs."},
            available_actions=list(self._ALL_ACTIONS),
            steps_taken=0,
            reward=0.10,
            done=False,
            feedback="Investigation started.",
        )

    def step(self, action: WhyDidItFailAction, timeout_s: Optional[float] = None, **kwargs: Any) -> WhyDidItFailObservation:
        """Advance one step; raises RuntimeError if reset() was never called."""
        if self.scenario is None:
            raise RuntimeError("Environment must be reset before calling step.")
        self._state.step_count += 1
        # Hard step limit — terminate immediately with the minimum reward.
        if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
            return self._step_limit_observation()
        if action.action_type in self._INSPECT_ACTIONS:
            return self._inspect_observation(action.action_type)
        if action.action_type == "submit_diagnosis":
            final_reward, feedback = self._grade(action)
            return WhyDidItFailObservation(
                task_description="Diagnosis submitted.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=final_reward,
                done=True,
                feedback=feedback,
            )
        # Unrecognized action: minimum reward, episode continues.
        return WhyDidItFailObservation(
            task_description="Continue your investigation.",
            visible_data={},
            available_actions=list(self._ALL_ACTIONS),
            steps_taken=self._state.step_count,
            reward=0.10,
            done=False,
            feedback=f"Unknown action '{action.action_type}'. Minimum reward.",
        )

    def _step_limit_observation(self) -> WhyDidItFailObservation:
        """Terminal observation emitted when the step budget is exhausted."""
        return WhyDidItFailObservation(
            task_description="Step limit reached. Episode terminated.",
            visible_data={},
            available_actions=[],
            steps_taken=self._state.step_count,
            reward=0.10,
            done=True,
            feedback=(
                f"Step limit ({self.max_steps}) reached without a diagnosis. "
                f"Score: 0.10. Actual failure: '{self.scenario['correct_diagnosis']}'."
            ),
        )

    def _inspect_observation(self, action_type: str) -> WhyDidItFailObservation:
        """Shared handler for the three inspect_* actions (previously triplicated).

        BUGFIX: the feedback is computed BEFORE the visit is recorded in
        ``inspection_order``. The old code appended first, so every first
        visit was misreported as "You already examined the ..." — the
        first-visit clue message (and its next-action hint) was unreachable.
        """
        source, data_key, payload_key = self._INSPECT_ACTIONS[action_type]
        required: list[str] = self.scenario.get("required_sources", ["logs"])
        step_reward = self._inspect_reward(source, required)
        feedback = self._inspect_feedback(source, required, step_reward)
        if source not in self.inspection_order:
            self.inspection_order.append(source)
        return WhyDidItFailObservation(
            task_description="Continue your investigation.",
            visible_data={payload_key: self.scenario[data_key]},
            available_actions=list(self._ALL_ACTIONS),
            steps_taken=self._state.step_count,
            reward=step_reward,
            done=False,
            feedback=feedback,
        )

    def _inspect_reward(self, source: str, required: list[str]) -> float:
        """Return step reward for an inspect action.

        Required sources: progressive — 0.50 / 0.30 / 0.15 for 1st/2nd/3rd discovery.
        Irrelevant sources: 0.10 (minimum; mild penalty via contrast with required rewards).
        Re-inspection: 0.10 (minimum; waste with no new information).
        All values are strictly in [0.10, 0.90].
        """
        if source in self.inspection_order:
            return 0.10  # redundant inspection — minimum reward
        if source in required:
            n_found = sum(1 for s in self.inspection_order if s in required)
            idx = min(n_found, len(self._REQUIRED_STEP_REWARDS) - 1)
            return self._REQUIRED_STEP_REWARDS[idx]
        return 0.10  # irrelevant source — minimum reward

    def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
        """Human-readable feedback for an inspect action.

        Must be called BEFORE ``source`` is appended to ``inspection_order``:
        the first branch detects re-inspection, and ``remaining_sources``
        excludes ``source`` explicitly on the assumption it is not yet recorded.
        """
        label = {"logs": "training logs", "config": "hyperparameter config", "gradients": "gradient statistics"}[source]
        if source in self.inspection_order:
            return f"You already examined the {label}. No new information gained."
        if source in required:
            remaining_sources = [s for s in required if s not in self.inspection_order and s != source]
            msg = f"You examined the {label}. Relevant clue found (+{reward:.2f})."
            if remaining_sources:
                next_source = f"inspect_{remaining_sources[0]}"
                msg += f" {len(remaining_sources)} required source(s) still unexamined. Next required action: {next_source}."
            return msg
        return f"You examined the {label}. This source is not required for this failure mode."

    def _grade(self, action: WhyDidItFailAction) -> tuple[float, str]:
        """Delegate to the unified grade() function and return (reward, feedback)."""
        assert self.scenario is not None
        diagnosis = (action.diagnosis or "").strip().lower()
        suggested_fix = (action.suggested_fix or "").strip().lower() or None
        difficulty = self.scenario["difficulty"]
        reward = grade(
            diagnosis=diagnosis,
            suggested_fix=suggested_fix,
            scenario=self.scenario,
            steps_taken=self._state.step_count,
            inspection_order=self.inspection_order,
            difficulty=difficulty,
        )
        # Feedback tiers mirror the grader's score bands.
        if reward >= 0.80:
            feedback = f"Excellent diagnosis! Score: {reward:.2f}"
        elif reward >= 0.50:
            feedback = f"Partially correct. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."
        else:
            feedback = f"Incorrect diagnosis. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."
        return reward, feedback