Spaces:
Sleeping
Sleeping
feat: max step limit
Browse files
- server/WhyDidItFail_environment.py +21 -0
- server/graders.py +5 -1
server/WhyDidItFail_environment.py
CHANGED
|
@@ -27,6 +27,7 @@ class WhyDidItFailEnvironment(Environment):
|
|
| 27 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 28 |
self.scenario: dict | None = None
|
| 29 |
self.inspection_order: list[str] = [] # first-visit order; doubles as membership check
|
|
|
|
| 30 |
|
| 31 |
@property
|
| 32 |
def state(self) -> State:
|
|
@@ -37,12 +38,17 @@ class WhyDidItFailEnvironment(Environment):
|
|
| 37 |
self.inspection_order = []
|
| 38 |
|
| 39 |
scenario_key = kwargs.get("scenario_key")
|
|
|
|
| 40 |
if scenario_key and scenario_key in SCENARIOS:
|
| 41 |
self.scenario = SCENARIOS[scenario_key]
|
| 42 |
else:
|
| 43 |
if seed is not None:
|
| 44 |
random.seed(seed)
|
| 45 |
self.scenario = random.choice(list(SCENARIOS.values()))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
return WhyDidItFailObservation(
|
| 47 |
task_description=(
|
| 48 |
"A training run has failed. Diagnose the root cause.\n"
|
|
@@ -62,6 +68,21 @@ class WhyDidItFailEnvironment(Environment):
|
|
| 62 |
raise RuntimeError("Environment must be reset before calling step.")
|
| 63 |
|
| 64 |
self._state.step_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
required: list[str] = self.scenario.get("required_sources", ["logs"])
|
| 66 |
|
| 67 |
if action.action_type == "inspect_logs":
|
|
|
|
| 27 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 28 |
self.scenario: dict | None = None
|
| 29 |
self.inspection_order: list[str] = [] # first-visit order; doubles as membership check
|
| 30 |
+
self.max_steps: int = 0
|
| 31 |
|
| 32 |
@property
|
| 33 |
def state(self) -> State:
|
|
|
|
| 38 |
self.inspection_order = []
|
| 39 |
|
| 40 |
scenario_key = kwargs.get("scenario_key")
|
| 41 |
+
|
| 42 |
if scenario_key and scenario_key in SCENARIOS:
|
| 43 |
self.scenario = SCENARIOS[scenario_key]
|
| 44 |
else:
|
| 45 |
if seed is not None:
|
| 46 |
random.seed(seed)
|
| 47 |
self.scenario = random.choice(list(SCENARIOS.values()))
|
| 48 |
+
|
| 49 |
+
required_sources = self.scenario.get("required_sources", ["logs"])
|
| 50 |
+
self.max_steps = len(required_sources) * 3 + 2
|
| 51 |
+
|
| 52 |
return WhyDidItFailObservation(
|
| 53 |
task_description=(
|
| 54 |
"A training run has failed. Diagnose the root cause.\n"
|
|
|
|
| 68 |
raise RuntimeError("Environment must be reset before calling step.")
|
| 69 |
|
| 70 |
self._state.step_count += 1
|
| 71 |
+
|
| 72 |
+
# Hard step limit — terminate immediately, grade() will return 0.0.
|
| 73 |
+
if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
|
| 74 |
+
return WhyDidItFailObservation(
|
| 75 |
+
task_description="Step limit reached. Episode terminated.",
|
| 76 |
+
visible_data={},
|
| 77 |
+
available_actions=[],
|
| 78 |
+
steps_taken=self._state.step_count,
|
| 79 |
+
reward=0.0,
|
| 80 |
+
done=True,
|
| 81 |
+
feedback=(
|
| 82 |
+
f"Step limit ({self.max_steps}) reached without a diagnosis. "
|
| 83 |
+
f"Score: 0.00. Actual failure: '{self.scenario['correct_diagnosis']}'."
|
| 84 |
+
),
|
| 85 |
+
)
|
| 86 |
required: list[str] = self.scenario.get("required_sources", ["logs"])
|
| 87 |
|
| 88 |
if action.action_type == "inspect_logs":
|
server/graders.py
CHANGED
|
@@ -206,7 +206,11 @@ def grade(
|
|
| 206 |
inspection_order = inspection_order or []
|
| 207 |
required_sources = scenario.get("required_sources", ["logs"]) # ordered list
|
| 208 |
required = set(required_sources) # set for membership checks
|
| 209 |
-
min_steps = len(required) + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
d_score = _diagnosis_score(diagnosis, scenario)
|
| 212 |
ed_penalty = _evidence_diagnosis_penalty(diagnosis, scenario, inspection_order)
|
|
|
|
| 206 |
inspection_order = inspection_order or []
|
| 207 |
required_sources = scenario.get("required_sources", ["logs"]) # ordered list
|
| 208 |
required = set(required_sources) # set for membership checks
|
| 209 |
+
min_steps = len(required) + 1 # inspect all required sources + submit
|
| 210 |
+
max_steps = len(required) * 3 + 2 # hard ceiling; exceeding it = total failure
|
| 211 |
+
|
| 212 |
+
if steps_taken > max_steps:
|
| 213 |
+
return 0.0
|
| 214 |
|
| 215 |
d_score = _diagnosis_score(diagnosis, scenario)
|
| 216 |
ed_penalty = _evidence_diagnosis_penalty(diagnosis, scenario, inspection_order)
|