Spaces:
Running
Running
Keep invalid-action task scores inside open interval
Browse files
server/environment.py
CHANGED
|
@@ -196,8 +196,9 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 196 |
allowed = set(task["allowed_fields"])
|
| 197 |
extra_fields = submitted_fields - allowed
|
| 198 |
if extra_fields:
|
| 199 |
-
# Penalty: record
|
| 200 |
-
|
|
|
|
| 201 |
self._state.average_score_so_far = self._current_average_score()
|
| 202 |
self._state.step_count += 1
|
| 203 |
self._state.current_ticket_index += 1
|
|
@@ -219,7 +220,7 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 219 |
else:
|
| 220 |
final_reward = 0.0
|
| 221 |
reward_components = self._build_reward_components(
|
| 222 |
-
ticket_score=
|
| 223 |
field_breakdown={},
|
| 224 |
shaped_step_reward=0.0,
|
| 225 |
reward_kind="trajectory" if is_done else "step_penalty",
|
|
@@ -249,7 +250,7 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 249 |
self._build_history_entry(
|
| 250 |
current_ticket,
|
| 251 |
predicted=action.model_dump(exclude_none=True),
|
| 252 |
-
score=
|
| 253 |
breakdown={},
|
| 254 |
queue_position=idx + 1,
|
| 255 |
reward=final_reward,
|
|
|
|
| 196 |
allowed = set(task["allowed_fields"])
|
| 197 |
extra_fields = submitted_fields - allowed
|
| 198 |
if extra_fields:
|
| 199 |
+
# Penalty: record an open-interval score, advance index, return penalty observation
|
| 200 |
+
invalid_score = clamp_open_unit_interval(0.0)
|
| 201 |
+
self._state.per_ticket_scores.append(invalid_score)
|
| 202 |
self._state.average_score_so_far = self._current_average_score()
|
| 203 |
self._state.step_count += 1
|
| 204 |
self._state.current_ticket_index += 1
|
|
|
|
| 220 |
else:
|
| 221 |
final_reward = 0.0
|
| 222 |
reward_components = self._build_reward_components(
|
| 223 |
+
ticket_score=invalid_score,
|
| 224 |
field_breakdown={},
|
| 225 |
shaped_step_reward=0.0,
|
| 226 |
reward_kind="trajectory" if is_done else "step_penalty",
|
|
|
|
| 250 |
self._build_history_entry(
|
| 251 |
current_ticket,
|
| 252 |
predicted=action.model_dump(exclude_none=True),
|
| 253 |
+
score=invalid_score,
|
| 254 |
breakdown={},
|
| 255 |
queue_position=idx + 1,
|
| 256 |
reward=final_reward,
|
tests/test_competitive_upgrade.py
CHANGED
|
@@ -746,9 +746,9 @@ class TestTerminalInvalidActionFinalReward(unittest.TestCase):
|
|
| 746 |
)
|
| 747 |
|
| 748 |
self.assertTrue(final_obs.done)
|
| 749 |
-
self.assertAlmostEqual(final_obs.reward, 0.
|
| 750 |
-
self.assertAlmostEqual(env.state.total_reward, 0.
|
| 751 |
-
self.assertAlmostEqual(env.state.reward or 0.0, 0.
|
| 752 |
|
| 753 |
|
| 754 |
# ---------------------------------------------------------------------------
|
|
|
|
| 746 |
)
|
| 747 |
|
| 748 |
self.assertTrue(final_obs.done)
|
| 749 |
+
self.assertAlmostEqual(final_obs.reward, 0.5, places=9)
|
| 750 |
+
self.assertAlmostEqual(env.state.total_reward, 0.5, places=9)
|
| 751 |
+
self.assertAlmostEqual(env.state.reward or 0.0, 0.5, places=9)
|
| 752 |
|
| 753 |
|
| 754 |
# ---------------------------------------------------------------------------
|
tests/test_extra_fields_penalty.py
CHANGED
|
@@ -77,8 +77,8 @@ class TestExtraFieldsPenalty(unittest.TestCase):
|
|
| 77 |
|
| 78 |
self.assertEqual(penalty_obs.tickets_processed, 1)
|
| 79 |
|
| 80 |
-
def
|
| 81 |
-
"""per_ticket_scores must
|
| 82 |
env = _make_env()
|
| 83 |
env.reset(seed=42, task_id=1)
|
| 84 |
|
|
@@ -90,7 +90,8 @@ class TestExtraFieldsPenalty(unittest.TestCase):
|
|
| 90 |
|
| 91 |
state = env.state
|
| 92 |
self.assertEqual(len(state.per_ticket_scores), 1)
|
| 93 |
-
self.
|
|
|
|
| 94 |
|
| 95 |
def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
|
| 96 |
"""History entry for a penalty step must include penalty_reason."""
|
|
@@ -107,7 +108,8 @@ class TestExtraFieldsPenalty(unittest.TestCase):
|
|
| 107 |
entry = penalty_obs.history[0]
|
| 108 |
self.assertIn("penalty_reason", entry)
|
| 109 |
self.assertIn("assignment_group", entry["penalty_reason"])
|
| 110 |
-
self.
|
|
|
|
| 111 |
|
| 112 |
def test_no_extra_fields_grades_normally(self) -> None:
|
| 113 |
"""When action fields are within allowed_fields, grading proceeds normally (reward != forced 0.0)."""
|
|
@@ -191,9 +193,10 @@ class TestExtraFieldsPenalty(unittest.TestCase):
|
|
| 191 |
final_obs = env.step(action)
|
| 192 |
|
| 193 |
self.assertTrue(final_obs.done)
|
| 194 |
-
|
| 195 |
-
self.
|
| 196 |
-
self.
|
|
|
|
| 197 |
|
| 198 |
|
| 199 |
if __name__ == "__main__":
|
|
|
|
| 77 |
|
| 78 |
self.assertEqual(penalty_obs.tickets_processed, 1)
|
| 79 |
|
| 80 |
+
def test_extra_fields_records_score_inside_open_interval(self) -> None:
|
| 81 |
+
"""per_ticket_scores must stay in the open interval after a penalty step."""
|
| 82 |
env = _make_env()
|
| 83 |
env.reset(seed=42, task_id=1)
|
| 84 |
|
|
|
|
| 90 |
|
| 91 |
state = env.state
|
| 92 |
self.assertEqual(len(state.per_ticket_scores), 1)
|
| 93 |
+
self.assertGreater(state.per_ticket_scores[0], 0.0)
|
| 94 |
+
self.assertLess(state.per_ticket_scores[0], 1.0)
|
| 95 |
|
| 96 |
def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
|
| 97 |
"""History entry for a penalty step must include penalty_reason."""
|
|
|
|
| 108 |
entry = penalty_obs.history[0]
|
| 109 |
self.assertIn("penalty_reason", entry)
|
| 110 |
self.assertIn("assignment_group", entry["penalty_reason"])
|
| 111 |
+
self.assertGreater(entry["score"], 0.0)
|
| 112 |
+
self.assertLess(entry["score"], 1.0)
|
| 113 |
|
| 114 |
def test_no_extra_fields_grades_normally(self) -> None:
|
| 115 |
"""When action fields are within allowed_fields, grading proceeds normally (reward != forced 0.0)."""
|
|
|
|
| 193 |
final_obs = env.step(action)
|
| 194 |
|
| 195 |
self.assertTrue(final_obs.done)
|
| 196 |
+
self.assertGreater(final_obs.reward, 0.0)
|
| 197 |
+
self.assertLess(final_obs.reward, 1.0)
|
| 198 |
+
self.assertGreater(env.state.total_reward, 0.0)
|
| 199 |
+
self.assertLess(env.state.total_reward, 1.0)
|
| 200 |
|
| 201 |
|
| 202 |
if __name__ == "__main__":
|