""" Tests for action field validation (Task 4) in HelpdeskTicketRoutingEnvironment.step(). Validates Requirement 7: Step Validates Action Fields Against Task Contract. """ from __future__ import annotations import contextlib import sys import os import unittest import types as _types sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) import openenv_test_stubs # noqa: F401 if "openenv.core.env_server.interfaces" not in sys.modules: _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces") class _Environment: def __init__(self) -> None: pass def __init_subclass__(cls, **kwargs: object) -> None: super().__init_subclass__(**kwargs) @classmethod def __class_getitem__(cls, item: object) -> type: return cls _interfaces_mod.Environment = _Environment # type: ignore[attr-defined] sys.modules["openenv.core.env_server.interfaces"] = _interfaces_mod from models import HelpdeskTicketAction, HelpdeskTicketObservation from server.environment import HelpdeskTicketRoutingEnvironment from server.tasks import TASKS from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS def _make_env() -> HelpdeskTicketRoutingEnvironment: return HelpdeskTicketRoutingEnvironment() def _task_with_issue_type_only(task_id: int) -> dict: task = dict(TASKS[task_id]) if task_id == 1: task["allowed_fields"] = ["issue_type"] return task @contextlib.contextmanager def _restrict_task_1_fields(): original_fields = list(TASKS[1]["allowed_fields"]) TASKS[1]["allowed_fields"] = ["issue_type"] try: yield finally: TASKS[1]["allowed_fields"] = original_fields class TestExtraFieldsPenalty(unittest.TestCase): """Requirement 7: step() rejects actions with fields outside the task's allowed_fields.""" def test_extra_fields_returns_closed_interval_penalty_reward(self) -> None: """Task 1 penalties should keep the returned reward inside the unit interval.""" env = _make_env() with _restrict_task_1_fields(): obs = env.reset(seed=42, task_id=1) # Task 1 allowed_fields should NOT include assignment_group self.assertNotIn("assignment_group", obs.allowed_fields) # Submit an action with an extra field (assignment_group) not in task 1's allowed_fields action = HelpdeskTicketAction( issue_type=ISSUE_TYPES[0], priority=PRIORITIES[0], assignment_group=ASSIGNMENT_GROUPS[0], # extra field ) penalty_obs = env.step(action) self.assertIsInstance(penalty_obs, HelpdeskTicketObservation) self.assertGreaterEqual(penalty_obs.reward, 0.0) self.assertLess(penalty_obs.reward, 1.0) def test_extra_fields_advances_ticket_index(self) -> None: """Penalty step must advance tickets_processed by 1.""" env = _make_env() with _restrict_task_1_fields(): obs = env.reset(seed=42, task_id=1) self.assertEqual(obs.tickets_processed, 0) action = HelpdeskTicketAction( issue_type=ISSUE_TYPES[0], assignment_group=ASSIGNMENT_GROUPS[0], # extra field for task 1 ) penalty_obs = env.step(action) self.assertEqual(penalty_obs.tickets_processed, 1) def test_extra_fields_records_score_inside_unit_interval(self) -> None: """per_ticket_scores must stay in the unit interval after a penalty step.""" env = _make_env() with _restrict_task_1_fields(): env.reset(seed=42, task_id=1) action = HelpdeskTicketAction( issue_type=ISSUE_TYPES[0], assignment_group=ASSIGNMENT_GROUPS[0], # extra field ) env.step(action) state = env.state self.assertEqual(len(state.per_ticket_scores), 1) self.assertGreaterEqual(state.per_ticket_scores[0], 0.0) self.assertLess(state.per_ticket_scores[0], 1.0) def test_extra_fields_history_entry_has_penalty_reason(self) -> None: """History entry for a penalty step must include penalty_reason.""" env = _make_env() with _restrict_task_1_fields(): env.reset(seed=42, task_id=1) action = HelpdeskTicketAction( issue_type=ISSUE_TYPES[0], assignment_group=ASSIGNMENT_GROUPS[0], # extra field ) penalty_obs = env.step(action) self.assertEqual(len(penalty_obs.history), 1) entry = penalty_obs.history[0] self.assertIn("penalty_reason", entry) self.assertIn("assignment_group", entry["penalty_reason"]) self.assertGreaterEqual(entry["score"], 0.0) self.assertLess(entry["score"], 1.0) def test_no_extra_fields_grades_normally(self) -> None: """When action fields are within allowed_fields, grading proceeds normally (reward != forced 0.0).""" env = _make_env() with _restrict_task_1_fields(): obs = env.reset(seed=42, task_id=1) # Build action using only allowed fields allowed = obs.allowed_fields action_kwargs = {} if "issue_type" in allowed: action_kwargs["issue_type"] = ISSUE_TYPES[0] if "priority" in allowed: action_kwargs["priority"] = PRIORITIES[0] action = HelpdeskTicketAction(**action_kwargs) result_obs = env.step(action) # Should be a valid observation; reward may be any value in [0.0, 1.0] self.assertIsInstance(result_obs, HelpdeskTicketObservation) self.assertIsNotNone(result_obs.reward) # No penalty_reason in history self.assertEqual(len(result_obs.history), 1) self.assertNotIn("penalty_reason", result_obs.history[0]) def test_action_metadata_is_not_treated_as_extra_field(self) -> None: """OpenEnv Action metadata should not trigger the extra-fields penalty.""" env = _make_env() with _restrict_task_1_fields(): obs = env.reset(seed=42, task_id=1) ticket_id = obs.current_ticket["ticket_id"] current_ticket = env._tickets_by_id[ticket_id] # noqa: SLF001 - test-only inspection result_obs = env.step( HelpdeskTicketAction( issue_type=current_ticket.issue_type, metadata={}, ) ) self.assertEqual(len(result_obs.history), 1) self.assertNotIn("penalty_reason", result_obs.history[0]) self.assertGreater(result_obs.history[0]["score"], 0.0) def test_extra_fields_no_exception_raised(self) -> None: """Requirement 7.4: extra fields must not raise an unhandled exception.""" env = _make_env() with _restrict_task_1_fields(): env.reset(seed=42, task_id=1) action = HelpdeskTicketAction( issue_type=ISSUE_TYPES[0], priority=PRIORITIES[0], assignment_group=ASSIGNMENT_GROUPS[0], resolution_action=RESOLUTION_ACTIONS[0], # multiple extra fields ) try: obs = env.step(action) except Exception as exc: # noqa: BLE001 self.fail(f"step() raised an unexpected exception: {exc}") self.assertIsInstance(obs, HelpdeskTicketObservation) def test_extra_fields_done_flag_set_correctly_on_last_ticket(self) -> None: """When the penalty step is on the last ticket, done stays True and reward stays episode-level.""" env = _make_env() with _restrict_task_1_fields(): obs = env.reset(seed=42, task_id=1) queue_size = obs.queue_size tickets_by_id = env._tickets_by_id # noqa: SLF001 - test-only inspection # Process all tickets except the last one normally for _ in range(queue_size - 1): current_ticket_id = obs.current_ticket["ticket_id"] current_ticket = tickets_by_id[current_ticket_id] obs = env.step(HelpdeskTicketAction(issue_type=current_ticket.issue_type)) # Now trigger penalty on the last ticket current_ticket_id = obs.current_ticket["ticket_id"] current_ticket = tickets_by_id[current_ticket_id] action = HelpdeskTicketAction( issue_type=current_ticket.issue_type, assignment_group=ASSIGNMENT_GROUPS[0], # extra field ) final_obs = env.step(action) self.assertTrue(final_obs.done) self.assertGreater(final_obs.reward, 0.0) self.assertLess(final_obs.reward, 1.0) self.assertGreater(env.state.total_reward, 0.0) self.assertLess(env.state.total_reward, 1.0) if __name__ == "__main__": unittest.main()