samrat-rm committed on
Commit
9f554a9
·
1 Parent(s): 572e42a

fix(grade): keyword matching and requires_fix flag for diagnosis scoring

Browse files
server/WhyDidItFail_environment.py CHANGED
@@ -86,6 +86,28 @@ class WhyDidItFailEnvironment(Environment):
86
  reward=0.0, done=False, feedback=feedback
87
  )
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def grade(self, action: WhyDidItFailAction) -> tuple[float, str, bool]:
90
  """Score a submit_diagnosis action against the current scenario."""
91
  if self.scenario is None:
@@ -95,12 +117,20 @@ class WhyDidItFailEnvironment(Environment):
95
  correct_fix = (self.scenario.get("correct_fix") or "").strip().lower()
96
  suggested_fix = (action.suggested_fix or "").strip().lower()
97
 
98
- diagnosis_correct = diagnosis == correct_diagnosis
99
- fix_correct = suggested_fix == correct_fix if correct_fix else True
 
 
 
 
 
 
 
 
100
 
101
  if diagnosis_correct and fix_correct:
102
  reward = 1.0
103
- feedback = "Correct diagnosis and fix!"
104
  elif diagnosis_correct:
105
  reward = 0.5
106
  feedback = f"Correct diagnosis, but the suggested fix was wrong. Expected: '{self.scenario.get('correct_fix')}'."
 
86
  reward=0.0, done=False, feedback=feedback
87
  )
88
 
89
+ @staticmethod
90
+ def _keywords_match(submitted: str, expected: str) -> bool:
91
+ """Return True if all significant keywords from expected appear in submitted.
92
+
93
+ Both strings should already be lowercased. Underscores and hyphens are
94
+ treated as spaces so "exploding_gradients" matches "exploding gradients".
95
+ Common stop words ("to", "a", "the", …) are ignored during keyword
96
+ extraction so that filler differences don't cause false negatives.
97
+ """
98
+ _STOP_WORDS = {"to", "a", "the", "and", "or", "is", "are", "was", "an", "in", "of"}
99
+
100
+ def _normalize(s: str) -> str:
101
+ return s.replace("_", " ").replace("-", " ")
102
+
103
+ submitted_norm = _normalize(submitted)
104
+ keywords = [
105
+ w for w in _normalize(expected).split()
106
+ if w not in _STOP_WORDS and len(w) > 1
107
+ ]
108
+ return all(kw in submitted_norm for kw in keywords)
109
+ # TODO: Partial-credit scoring, configurable keyword aliases per scenario, false-positive guard.
110
+
111
  def grade(self, action: WhyDidItFailAction) -> tuple[float, str, bool]:
112
  """Score a submit_diagnosis action against the current scenario."""
113
  if self.scenario is None:
 
117
  correct_fix = (self.scenario.get("correct_fix") or "").strip().lower()
118
  suggested_fix = (action.suggested_fix or "").strip().lower()
119
 
120
+ requires_fix: bool = self.scenario.get("requires_fix", False)
121
+
122
+ diagnosis_correct = self._keywords_match(diagnosis, correct_diagnosis)
123
+ if not requires_fix:
124
+ fix_correct = True # fix not evaluated for this scenario
125
+ elif not correct_fix:
126
+ # Scenario marked requires_fix=True but forgot to set correct_fix — safe default.
127
+ fix_correct = False
128
+ else:
129
+ fix_correct = self._keywords_match(suggested_fix, correct_fix)
130
 
131
  if diagnosis_correct and fix_correct:
132
  reward = 1.0
133
+ feedback = "Correct diagnosis and fix!" if requires_fix else "Correct diagnosis!"
134
  elif diagnosis_correct:
135
  reward = 0.5
136
  feedback = f"Correct diagnosis, but the suggested fix was wrong. Expected: '{self.scenario.get('correct_fix')}'."
server/scenarios.py CHANGED
@@ -18,6 +18,7 @@ SCENARIOS = {
18
  "gradient_norms": None, # not visible until agent requests it
19
  "correct_diagnosis": "exploding_gradients",
20
  "correct_fix": "reduce learning_rate to 0.001",
 
21
  }
22
  }
23
  # TODO : Add more scenarios
 
18
  "gradient_norms": None, # not visible until agent requests it
19
  "correct_diagnosis": "exploding_gradients",
20
  "correct_fix": "reduce learning_rate to 0.001",
21
+ "requires_fix": False, # set True on hard scenarios where fix must be graded
22
  }
23
  }
24
  # TODO : Add more scenarios