samrat-rm commited on
Commit
1288c52
·
1 Parent(s): c6888af

feat: max step limit

Browse files
server/WhyDidItFail_environment.py CHANGED
@@ -27,6 +27,7 @@ class WhyDidItFailEnvironment(Environment):
27
  self._state = State(episode_id=str(uuid4()), step_count=0)
28
  self.scenario: dict | None = None
29
  self.inspection_order: list[str] = [] # first-visit order; doubles as membership check
 
30
 
31
  @property
32
  def state(self) -> State:
@@ -37,12 +38,17 @@ class WhyDidItFailEnvironment(Environment):
37
  self.inspection_order = []
38
 
39
  scenario_key = kwargs.get("scenario_key")
 
40
  if scenario_key and scenario_key in SCENARIOS:
41
  self.scenario = SCENARIOS[scenario_key]
42
  else:
43
  if seed is not None:
44
  random.seed(seed)
45
  self.scenario = random.choice(list(SCENARIOS.values()))
 
 
 
 
46
  return WhyDidItFailObservation(
47
  task_description=(
48
  "A training run has failed. Diagnose the root cause.\n"
@@ -62,6 +68,21 @@ class WhyDidItFailEnvironment(Environment):
62
  raise RuntimeError("Environment must be reset before calling step.")
63
 
64
  self._state.step_count += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  required: list[str] = self.scenario.get("required_sources", ["logs"])
66
 
67
  if action.action_type == "inspect_logs":
 
27
  self._state = State(episode_id=str(uuid4()), step_count=0)
28
  self.scenario: dict | None = None
29
  self.inspection_order: list[str] = [] # first-visit order; doubles as membership check
30
+ self.max_steps: int = 0
31
 
32
  @property
33
  def state(self) -> State:
 
38
  self.inspection_order = []
39
 
40
  scenario_key = kwargs.get("scenario_key")
41
+
42
  if scenario_key and scenario_key in SCENARIOS:
43
  self.scenario = SCENARIOS[scenario_key]
44
  else:
45
  if seed is not None:
46
  random.seed(seed)
47
  self.scenario = random.choice(list(SCENARIOS.values()))
48
+
49
+ required_sources = self.scenario.get("required_sources", ["logs"])
50
+ self.max_steps = len(required_sources) * 3 + 2
51
+
52
  return WhyDidItFailObservation(
53
  task_description=(
54
  "A training run has failed. Diagnose the root cause.\n"
 
68
  raise RuntimeError("Environment must be reset before calling step.")
69
 
70
  self._state.step_count += 1
71
+
72
+ # Hard step limit — terminate immediately, grade() will return 0.0.
73
+ if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
74
+ return WhyDidItFailObservation(
75
+ task_description="Step limit reached. Episode terminated.",
76
+ visible_data={},
77
+ available_actions=[],
78
+ steps_taken=self._state.step_count,
79
+ reward=0.0,
80
+ done=True,
81
+ feedback=(
82
+ f"Step limit ({self.max_steps}) reached without a diagnosis. "
83
+ f"Score: 0.00. Actual failure: '{self.scenario['correct_diagnosis']}'."
84
+ ),
85
+ )
86
  required: list[str] = self.scenario.get("required_sources", ["logs"])
87
 
88
  if action.action_type == "inspect_logs":
server/graders.py CHANGED
@@ -206,7 +206,11 @@ def grade(
206
  inspection_order = inspection_order or []
207
  required_sources = scenario.get("required_sources", ["logs"]) # ordered list
208
  required = set(required_sources) # set for membership checks
209
- min_steps = len(required) + 1 # inspect all required sources + submit
 
 
 
 
210
 
211
  d_score = _diagnosis_score(diagnosis, scenario)
212
  ed_penalty = _evidence_diagnosis_penalty(diagnosis, scenario, inspection_order)
 
206
  inspection_order = inspection_order or []
207
  required_sources = scenario.get("required_sources", ["logs"]) # ordered list
208
  required = set(required_sources) # set for membership checks
209
+ min_steps = len(required) + 1 # inspect all required sources + submit
210
+ max_steps = len(required) * 3 + 2 # hard ceiling; exceeding it = total failure
211
+
212
+ if steps_taken > max_steps:
213
+ return 0.0
214
 
215
  d_score = _diagnosis_score(diagnosis, scenario)
216
  ed_penalty = _evidence_diagnosis_penalty(diagnosis, scenario, inspection_order)