Roopalgn committed on
Commit
c0d489c
·
1 Parent(s): e3dfee6

Keep invalid-action task scores inside open interval

Browse files
server/environment.py CHANGED
@@ -196,8 +196,9 @@ class HelpdeskTicketRoutingEnvironment(
196
  allowed = set(task["allowed_fields"])
197
  extra_fields = submitted_fields - allowed
198
  if extra_fields:
199
- # Penalty: record score 0.0, advance index, return penalty observation
200
- self._state.per_ticket_scores.append(0.0)
 
201
  self._state.average_score_so_far = self._current_average_score()
202
  self._state.step_count += 1
203
  self._state.current_ticket_index += 1
@@ -219,7 +220,7 @@ class HelpdeskTicketRoutingEnvironment(
219
  else:
220
  final_reward = 0.0
221
  reward_components = self._build_reward_components(
222
- ticket_score=0.0,
223
  field_breakdown={},
224
  shaped_step_reward=0.0,
225
  reward_kind="trajectory" if is_done else "step_penalty",
@@ -249,7 +250,7 @@ class HelpdeskTicketRoutingEnvironment(
249
  self._build_history_entry(
250
  current_ticket,
251
  predicted=action.model_dump(exclude_none=True),
252
- score=0.0,
253
  breakdown={},
254
  queue_position=idx + 1,
255
  reward=final_reward,
 
196
  allowed = set(task["allowed_fields"])
197
  extra_fields = submitted_fields - allowed
198
  if extra_fields:
199
+ # Penalty: record an open-interval score, advance index, return penalty observation
200
+ invalid_score = clamp_open_unit_interval(0.0)
201
+ self._state.per_ticket_scores.append(invalid_score)
202
  self._state.average_score_so_far = self._current_average_score()
203
  self._state.step_count += 1
204
  self._state.current_ticket_index += 1
 
220
  else:
221
  final_reward = 0.0
222
  reward_components = self._build_reward_components(
223
+ ticket_score=invalid_score,
224
  field_breakdown={},
225
  shaped_step_reward=0.0,
226
  reward_kind="trajectory" if is_done else "step_penalty",
 
250
  self._build_history_entry(
251
  current_ticket,
252
  predicted=action.model_dump(exclude_none=True),
253
+ score=invalid_score,
254
  breakdown={},
255
  queue_position=idx + 1,
256
  reward=final_reward,
tests/test_competitive_upgrade.py CHANGED
@@ -746,9 +746,9 @@ class TestTerminalInvalidActionFinalReward(unittest.TestCase):
746
  )
747
 
748
  self.assertTrue(final_obs.done)
749
- self.assertAlmostEqual(final_obs.reward, 0.4995, places=9)
750
- self.assertAlmostEqual(env.state.total_reward, 0.4995, places=9)
751
- self.assertAlmostEqual(env.state.reward or 0.0, 0.4995, places=9)
752
 
753
 
754
  # ---------------------------------------------------------------------------
 
746
  )
747
 
748
  self.assertTrue(final_obs.done)
749
+ self.assertAlmostEqual(final_obs.reward, 0.5, places=9)
750
+ self.assertAlmostEqual(env.state.total_reward, 0.5, places=9)
751
+ self.assertAlmostEqual(env.state.reward or 0.0, 0.5, places=9)
752
 
753
 
754
  # ---------------------------------------------------------------------------
tests/test_extra_fields_penalty.py CHANGED
@@ -77,8 +77,8 @@ class TestExtraFieldsPenalty(unittest.TestCase):
77
 
78
  self.assertEqual(penalty_obs.tickets_processed, 1)
79
 
80
- def test_extra_fields_records_score_zero(self) -> None:
81
- """per_ticket_scores must contain 0.0 after a penalty step."""
82
  env = _make_env()
83
  env.reset(seed=42, task_id=1)
84
 
@@ -90,7 +90,8 @@ class TestExtraFieldsPenalty(unittest.TestCase):
90
 
91
  state = env.state
92
  self.assertEqual(len(state.per_ticket_scores), 1)
93
- self.assertEqual(state.per_ticket_scores[0], 0.0)
 
94
 
95
  def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
96
  """History entry for a penalty step must include penalty_reason."""
@@ -107,7 +108,8 @@ class TestExtraFieldsPenalty(unittest.TestCase):
107
  entry = penalty_obs.history[0]
108
  self.assertIn("penalty_reason", entry)
109
  self.assertIn("assignment_group", entry["penalty_reason"])
110
- self.assertEqual(entry["score"], 0.0)
 
111
 
112
  def test_no_extra_fields_grades_normally(self) -> None:
113
  """When action fields are within allowed_fields, grading proceeds normally (reward != forced 0.0)."""
@@ -191,9 +193,10 @@ class TestExtraFieldsPenalty(unittest.TestCase):
191
  final_obs = env.step(action)
192
 
193
  self.assertTrue(final_obs.done)
194
- expected_reward = (queue_size - 1) / queue_size
195
- self.assertAlmostEqual(final_obs.reward, expected_reward, places=9)
196
- self.assertAlmostEqual(env.state.total_reward, expected_reward, places=9)
 
197
 
198
 
199
  if __name__ == "__main__":
 
77
 
78
  self.assertEqual(penalty_obs.tickets_processed, 1)
79
 
80
+ def test_extra_fields_records_score_inside_open_interval(self) -> None:
81
+ """per_ticket_scores must stay in the open interval after a penalty step."""
82
  env = _make_env()
83
  env.reset(seed=42, task_id=1)
84
 
 
90
 
91
  state = env.state
92
  self.assertEqual(len(state.per_ticket_scores), 1)
93
+ self.assertGreater(state.per_ticket_scores[0], 0.0)
94
+ self.assertLess(state.per_ticket_scores[0], 1.0)
95
 
96
  def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
97
  """History entry for a penalty step must include penalty_reason."""
 
108
  entry = penalty_obs.history[0]
109
  self.assertIn("penalty_reason", entry)
110
  self.assertIn("assignment_group", entry["penalty_reason"])
111
+ self.assertGreater(entry["score"], 0.0)
112
+ self.assertLess(entry["score"], 1.0)
113
 
114
  def test_no_extra_fields_grades_normally(self) -> None:
115
  """When action fields are within allowed_fields, grading proceeds normally (reward != forced 0.0)."""
 
193
  final_obs = env.step(action)
194
 
195
  self.assertTrue(final_obs.done)
196
+ self.assertGreater(final_obs.reward, 0.0)
197
+ self.assertLess(final_obs.reward, 1.0)
198
+ self.assertGreater(env.state.total_reward, 0.0)
199
+ self.assertLess(env.state.total_reward, 1.0)
200
 
201
 
202
  if __name__ == "__main__":