Spaces:
Running
Running
| """ | |
| Tests for action field validation (Task 4) in HelpdeskTicketRoutingEnvironment.step(). | |
| Validates Requirement 7: Step Validates Action Fields Against Task Contract. | |
| """ | |
| from __future__ import annotations | |
| import contextlib | |
| import sys | |
| import os | |
| import unittest | |
| import types as _types | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| import openenv_test_stubs # noqa: F401 | |
# Register a minimal stand-in for ``openenv.core.env_server.interfaces`` when no
# module of that name has been loaded yet. Presumably this lets the imports
# that follow resolve ``Environment`` without the real package - confirm.
if "openenv.core.env_server.interfaces" not in sys.modules:

    class _Environment:
        """Bare placeholder exposed as ``Environment`` on the stub module."""

        def __init__(self) -> None:
            """Accept no arguments and set up no state."""

        def __init_subclass__(cls, **kwargs: object) -> None:
            """Forward subclass keyword arguments up the normal chain."""
            super().__init_subclass__(**kwargs)

        def __class_getitem__(cls, item: object) -> type:
            """Make ``Environment[...]`` evaluate to the class itself."""
            return cls

    _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")
    _interfaces_mod.Environment = _Environment  # type: ignore[attr-defined]
    # The module was created under the dotted name, so __name__ is the key.
    sys.modules[_interfaces_mod.__name__] = _interfaces_mod
| from models import HelpdeskTicketAction, HelpdeskTicketObservation | |
| from server.environment import HelpdeskTicketRoutingEnvironment | |
| from server.tasks import TASKS | |
| from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS | |
def _make_env() -> HelpdeskTicketRoutingEnvironment:
    """Build a fresh environment instance using its default constructor."""
    environment = HelpdeskTicketRoutingEnvironment()
    return environment
def _task_with_issue_type_only(task_id: int) -> dict:
    """Return a shallow copy of ``TASKS[task_id]``.

    For task 1 the copy's ``allowed_fields`` is replaced with
    ``["issue_type"]``; any other task is copied unchanged. The entry stored
    in ``TASKS`` itself is not reassigned.
    """
    task_copy = dict(TASKS[task_id])
    if task_id != 1:
        return task_copy
    task_copy["allowed_fields"] = ["issue_type"]
    return task_copy
@contextlib.contextmanager
def _restrict_task_1_fields():
    """Temporarily narrow task 1's ``allowed_fields`` to ``["issue_type"]``.

    The shared ``TASKS[1]`` entry is mutated in place for the duration of the
    ``with`` block. On exit a copy of the previous field list is put back,
    even if the body raises.

    The ``contextlib.contextmanager`` decorator is required: without it the
    call returns a plain generator object, which cannot be used in a ``with``
    statement.

    Yields:
        None. The context manager is used only for its side effect.
    """
    # Copy the list so later in-place changes cannot alter what is restored.
    original_fields = list(TASKS[1]["allowed_fields"])
    TASKS[1]["allowed_fields"] = ["issue_type"]
    try:
        yield
    finally:
        TASKS[1]["allowed_fields"] = original_fields
class TestExtraFieldsPenalty(unittest.TestCase):
    """Requirement 7: step() rejects actions with fields outside the task's allowed_fields.

    Every test resets a fresh environment on task 1 with seed 42 while
    ``_restrict_task_1_fields()`` is active, and keeps its ``step()`` calls and
    assertions inside that ``with`` block.
    """

    def test_extra_fields_returns_closed_interval_penalty_reward(self) -> None:
        """A step carrying an extra field must report a reward in [0.0, 1.0)."""
        env = _make_env()
        with _restrict_task_1_fields():
            initial = env.reset(seed=42, task_id=1)
            # Precondition: assignment_group is not in the reported contract.
            self.assertNotIn("assignment_group", initial.allowed_fields)

            # assignment_group is the deliberately out-of-contract field.
            offending_action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                priority=PRIORITIES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],
            )
            outcome = env.step(offending_action)

            self.assertIsInstance(outcome, HelpdeskTicketObservation)
            self.assertGreaterEqual(outcome.reward, 0.0)
            self.assertLess(outcome.reward, 1.0)

    def test_extra_fields_advances_ticket_index(self) -> None:
        """A step carrying an extra field must move tickets_processed from 0 to 1."""
        env = _make_env()
        with _restrict_task_1_fields():
            initial = env.reset(seed=42, task_id=1)
            self.assertEqual(initial.tickets_processed, 0)

            # assignment_group is the extra field for task 1.
            outcome = env.step(
                HelpdeskTicketAction(
                    issue_type=ISSUE_TYPES[0],
                    assignment_group=ASSIGNMENT_GROUPS[0],
                )
            )

            self.assertEqual(outcome.tickets_processed, 1)

    def test_extra_fields_records_score_inside_unit_interval(self) -> None:
        """After one extra-field step, the single recorded score lies in [0.0, 1.0)."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)

            # assignment_group is the extra field.
            env.step(
                HelpdeskTicketAction(
                    issue_type=ISSUE_TYPES[0],
                    assignment_group=ASSIGNMENT_GROUPS[0],
                )
            )

            snapshot = env.state
            self.assertEqual(len(snapshot.per_ticket_scores), 1)
            self.assertGreaterEqual(snapshot.per_ticket_scores[0], 0.0)
            self.assertLess(snapshot.per_ticket_scores[0], 1.0)

    def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
        """The history entry for an extra-field step names the offending field."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)

            # assignment_group is the extra field.
            outcome = env.step(
                HelpdeskTicketAction(
                    issue_type=ISSUE_TYPES[0],
                    assignment_group=ASSIGNMENT_GROUPS[0],
                )
            )

            self.assertEqual(len(outcome.history), 1)
            record = outcome.history[0]
            self.assertIn("penalty_reason", record)
            self.assertIn("assignment_group", record["penalty_reason"])
            self.assertGreaterEqual(record["score"], 0.0)
            self.assertLess(record["score"], 1.0)

    def test_no_extra_fields_grades_normally(self) -> None:
        """An action limited to allowed_fields is graded with no penalty_reason recorded."""
        env = _make_env()
        with _restrict_task_1_fields():
            contract = env.reset(seed=42, task_id=1).allowed_fields

            # Fill in only the candidate fields that the contract permits.
            field_sources = (("issue_type", ISSUE_TYPES), ("priority", PRIORITIES))
            permitted = {
                name: choices[0]
                for name, choices in field_sources
                if name in contract
            }
            graded = env.step(HelpdeskTicketAction(**permitted))

            # A valid observation with some reward value must come back.
            self.assertIsInstance(graded, HelpdeskTicketObservation)
            self.assertIsNotNone(graded.reward)
            # The single history entry carries no penalty marker.
            self.assertEqual(len(graded.history), 1)
            self.assertNotIn("penalty_reason", graded.history[0])

    def test_action_metadata_is_not_treated_as_extra_field(self) -> None:
        """Passing ``metadata`` on the action must not trigger the extra-fields penalty."""
        env = _make_env()
        with _restrict_task_1_fields():
            initial = env.reset(seed=42, task_id=1)
            # Look up the environment's own record for the ticket on offer.
            record = env._tickets_by_id[initial.current_ticket["ticket_id"]]  # noqa: SLF001 - test-only inspection

            graded = env.step(
                HelpdeskTicketAction(
                    issue_type=record.issue_type,
                    metadata={},
                )
            )

            self.assertEqual(len(graded.history), 1)
            only_entry = graded.history[0]
            self.assertNotIn("penalty_reason", only_entry)
            self.assertGreater(only_entry["score"], 0.0)

    def test_extra_fields_no_exception_raised(self) -> None:
        """Requirement 7.4: several extra fields at once must not make step() raise."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)

            # More than one field here falls outside the restricted contract.
            overloaded_action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                priority=PRIORITIES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],
                resolution_action=RESOLUTION_ACTIONS[0],
            )
            try:
                outcome = env.step(overloaded_action)
            except Exception as exc:  # noqa: BLE001
                self.fail(f"step() raised an unexpected exception: {exc}")
            else:
                self.assertIsInstance(outcome, HelpdeskTicketObservation)

    def test_extra_fields_done_flag_set_correctly_on_last_ticket(self) -> None:
        """A penalty on the final ticket ends the episode with rewards strictly between 0 and 1."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            ticket_records = env._tickets_by_id  # noqa: SLF001 - test-only inspection
            steps_before_last = obs.queue_size - 1

            # Answer all but the last ticket using only issue_type, taken from
            # the environment's own record (presumably the expected answer).
            for _ in range(steps_before_last):
                record = ticket_records[obs.current_ticket["ticket_id"]]
                obs = env.step(HelpdeskTicketAction(issue_type=record.issue_type))

            # The last ticket gets the same issue_type plus an extra field.
            last_record = ticket_records[obs.current_ticket["ticket_id"]]
            final_obs = env.step(
                HelpdeskTicketAction(
                    issue_type=last_record.issue_type,
                    assignment_group=ASSIGNMENT_GROUPS[0],
                )
            )

            self.assertTrue(final_obs.done)
            self.assertGreater(final_obs.reward, 0.0)
            self.assertLess(final_obs.reward, 1.0)
            self.assertGreater(env.state.total_reward, 0.0)
            self.assertLess(env.state.total_reward, 1.0)
# Run the unittest command-line runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()