samrat-rm commited on
Commit
a0518e7
·
1 Parent(s): 9f554a9

fix: seed and episode_id in reset()

Browse files
Files changed (1) hide show
  1. server/WhyDidItFail_environment.py +6 -3
server/WhyDidItFail_environment.py CHANGED
@@ -29,10 +29,12 @@ class WhyDidItFailEnvironment(Environment):
29
  def __init__(self):
30
  self._state = State(episode_id=str(uuid4()), step_count=0)
31
  self.scenario = None
32
- self.inspected = set() # tracks what the agent has already looked at
33
 
34
  def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
35
- self._state = State(episode_id=str(uuid4()), step_count=0)
 
 
36
  self.scenario = random.choice(list(SCENARIOS.values()))
37
  self.inspected = set()
38
  return WhyDidItFailObservation(
@@ -106,10 +108,11 @@ class WhyDidItFailEnvironment(Environment):
106
  if w not in _STOP_WORDS and len(w) > 1
107
  ]
108
  return all(kw in submitted_norm for kw in keywords)
109
- # TODO : Partial credit scoreing, Configurable keyword aliases per scenario, False positive Gaurd,
110
 
111
  def grade(self, action: WhyDidItFailAction) -> tuple[float, str, bool]:
112
  """Score a submit_diagnosis action against the current scenario."""
 
113
  if self.scenario is None:
114
  raise RuntimeError("Environment must be reset before calling grade.")
115
  diagnosis = (action.diagnosis or "").strip().lower()
 
29
  def __init__(self):
30
  self._state = State(episode_id=str(uuid4()), step_count=0)
31
  self.scenario = None
32
+ self.inspected = set() # tracks what the agent has already looked at TODO : implement inspected logic
33
 
34
  def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
35
+ if seed is not None:
36
+ random.seed(seed)
37
+ self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
38
  self.scenario = random.choice(list(SCENARIOS.values()))
39
  self.inspected = set()
40
  return WhyDidItFailObservation(
 
108
  if w not in _STOP_WORDS and len(w) > 1
109
  ]
110
  return all(kw in submitted_norm for kw in keywords)
111
+ # TODO : Improve scoring : Partial credit scoreing, Configurable keyword aliases per scenario, False positive Gaurd,
112
 
113
  def grade(self, action: WhyDidItFailAction) -> tuple[float, str, bool]:
114
  """Score a submit_diagnosis action against the current scenario."""
115
+ # TODO : use step count in reward calc
116
  if self.scenario is None:
117
  raise RuntimeError("Environment must be reset before calling grade.")
118
  diagnosis = (action.diagnosis or "").strip().lower()