Spaces:
Running
Running
File size: 9,002 Bytes
42dd095 043d9e1 42dd095 043d9e1 42dd095 8241eb5 42dd095 043d9e1 42dd095 043d9e1 42dd095 043d9e1 42dd095 8241eb5 d6d9493 42dd095 043d9e1 42dd095 043d9e1 42dd095 8241eb5 42dd095 043d9e1 42dd095 043d9e1 42dd095 8241eb5 c0d489c 42dd095 043d9e1 42dd095 043d9e1 42dd095 8241eb5 c0d489c 42dd095 043d9e1 42dd095 043d9e1 42dd095 043d9e1 42dd095 8ccf96d 043d9e1 8ccf96d 42dd095 043d9e1 42dd095 c64d203 42dd095 043d9e1 c64d203 043d9e1 42dd095 c0d489c 42dd095 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 | """
Tests for action field validation (Task 4) in HelpdeskTicketRoutingEnvironment.step().
Validates Requirement 7: Step Validates Action Fields Against Task Contract.
"""
from __future__ import annotations
import contextlib
import sys
import os
import unittest
import types as _types
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import openenv_test_stubs # noqa: F401
if "openenv.core.env_server.interfaces" not in sys.modules:
    # Install a minimal stand-in for the OpenEnv interfaces module so that
    # ``server.environment`` can be imported without the real package.

    class _Environment:
        """Stub ``Environment`` base: subscriptable and subclassable, with no behaviour."""

        def __init__(self) -> None:
            pass

        def __init_subclass__(cls, **kwargs: object) -> None:
            super().__init_subclass__(**kwargs)

        @classmethod
        def __class_getitem__(cls, item: object) -> type:
            # ``Environment[ActionT, ObsT]`` ignores its type arguments and
            # hands back the class itself.
            return cls

    _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")
    setattr(_interfaces_mod, "Environment", _Environment)
    sys.modules[_interfaces_mod.__name__] = _interfaces_mod
from models import HelpdeskTicketAction, HelpdeskTicketObservation
from server.environment import HelpdeskTicketRoutingEnvironment
from server.tasks import TASKS
from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS
def _make_env() -> HelpdeskTicketRoutingEnvironment:
    """Construct a new routing environment; each test calls ``reset()`` on it itself."""
    environment = HelpdeskTicketRoutingEnvironment()
    return environment
def _task_with_issue_type_only(task_id: int) -> dict:
    """Return a shallow copy of ``TASKS[task_id]``.

    For task 1 the copy's ``allowed_fields`` is replaced by ``["issue_type"]``;
    any other task is copied unchanged. The top-level entries of ``TASKS`` are
    never modified, though nested values are shared with the copy.
    """
    task_copy = dict(TASKS[task_id])
    if task_id != 1:
        return task_copy
    task_copy.update(allowed_fields=["issue_type"])
    return task_copy
@contextlib.contextmanager
def _restrict_task_1_fields():
    """Temporarily narrow task 1's ``allowed_fields`` to ``["issue_type"]``.

    The global ``TASKS`` table is patched in place, so code that reads
    ``TASKS`` inside the ``with`` block sees the restricted contract. The
    original value is restored on exit, even if the body raises.
    """
    # Keep the original object itself rather than a ``list()`` copy. The entry
    # is replaced (never mutated) below, so restoring this reference puts back
    # the exact value: same type (e.g. a tuple stays a tuple) and same identity
    # for anything that aliased it.
    original_fields = TASKS[1]["allowed_fields"]
    # A fresh list on every entry, so one test cannot leak mutations into the next.
    TASKS[1]["allowed_fields"] = ["issue_type"]
    try:
        yield
    finally:
        TASKS[1]["allowed_fields"] = original_fields
class TestExtraFieldsPenalty(unittest.TestCase):
    """Requirement 7: step() rejects actions with fields outside the task's allowed_fields.

    Every test narrows task 1's ``allowed_fields`` to ``["issue_type"]`` via
    ``_restrict_task_1_fields``. As a result, ``priority``, ``assignment_group``
    and ``resolution_action`` all count as extra fields.
    """

    def _assert_in_unit_interval(self, value: float | None, *, exclude_zero: bool = False) -> None:
        """Assert *value* lies in ``[0.0, 1.0)``, or in ``(0.0, 1.0)`` when *exclude_zero*.

        ``None`` is rejected first, so a missing reward or score shows up as a
        clear assertion failure instead of a ``TypeError`` from the comparison.
        """
        self.assertIsNotNone(value)
        if exclude_zero:
            self.assertGreater(value, 0.0)
        else:
            self.assertGreaterEqual(value, 0.0)
        self.assertLess(value, 1.0)

    def test_extra_fields_returns_closed_interval_penalty_reward(self) -> None:
        """A task 1 penalty reward must lie in [0.0, 1.0) (test name kept for ID stability)."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            # assignment_group must be outside the allowed fields for this to be a penalty case.
            self.assertNotIn("assignment_group", obs.allowed_fields)
            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                priority=PRIORITIES[0],  # extra field under the restriction
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
            )
            penalty_obs = env.step(action)
            self.assertIsInstance(penalty_obs, HelpdeskTicketObservation)
            self._assert_in_unit_interval(penalty_obs.reward)

    def test_extra_fields_advances_ticket_index(self) -> None:
        """A penalty step must advance tickets_processed by 1."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            self.assertEqual(obs.tickets_processed, 0)
            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field for task 1
            )
            penalty_obs = env.step(action)
            self.assertEqual(penalty_obs.tickets_processed, 1)

    def test_extra_fields_records_score_inside_unit_interval(self) -> None:
        """After a penalty step, per_ticket_scores must hold one score in [0.0, 1.0)."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)
            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
            )
            env.step(action)
            state = env.state
            self.assertEqual(len(state.per_ticket_scores), 1)
            self._assert_in_unit_interval(state.per_ticket_scores[0])

    def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
        """The history entry for a penalty step must name the offending field in penalty_reason."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)
            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
            )
            penalty_obs = env.step(action)
            self.assertEqual(len(penalty_obs.history), 1)
            entry = penalty_obs.history[0]
            self.assertIn("penalty_reason", entry)
            self.assertIn("assignment_group", entry["penalty_reason"])
            self._assert_in_unit_interval(entry["score"])

    def test_no_extra_fields_grades_normally(self) -> None:
        """Actions limited to allowed_fields are graded normally: no penalty_reason, reward in [0.0, 1.0]."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            # Send only those candidate fields that the task actually allows.
            candidates = {"issue_type": ISSUE_TYPES[0], "priority": PRIORITIES[0]}
            action_kwargs = {
                field: value
                for field, value in candidates.items()
                if field in obs.allowed_fields
            }
            result_obs = env.step(HelpdeskTicketAction(**action_kwargs))
            self.assertIsInstance(result_obs, HelpdeskTicketObservation)
            # A normally graded reward may take any value in the closed interval [0.0, 1.0].
            self.assertIsNotNone(result_obs.reward)
            self.assertGreaterEqual(result_obs.reward, 0.0)
            self.assertLessEqual(result_obs.reward, 1.0)
            # A normally graded step must not carry a penalty_reason.
            self.assertEqual(len(result_obs.history), 1)
            self.assertNotIn("penalty_reason", result_obs.history[0])

    def test_action_metadata_is_not_treated_as_extra_field(self) -> None:
        """OpenEnv Action metadata must not trigger the extra-fields penalty."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            ticket_id = obs.current_ticket["ticket_id"]
            current_ticket = env._tickets_by_id[ticket_id]  # noqa: SLF001 - test-only inspection
            result_obs = env.step(
                HelpdeskTicketAction(
                    issue_type=current_ticket.issue_type,
                    metadata={},
                )
            )
            self.assertEqual(len(result_obs.history), 1)
            self.assertNotIn("penalty_reason", result_obs.history[0])
            self.assertGreater(result_obs.history[0]["score"], 0.0)

    def test_extra_fields_no_exception_raised(self) -> None:
        """Requirement 7.4: extra fields must not raise an unhandled exception."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)
            # priority, assignment_group and resolution_action are all extra here.
            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                priority=PRIORITIES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],
                resolution_action=RESOLUTION_ACTIONS[0],
            )
            try:
                obs = env.step(action)
            except Exception as exc:  # noqa: BLE001 - any exception is a test failure
                self.fail(f"step() raised an unexpected exception: {exc}")
            self.assertIsInstance(obs, HelpdeskTicketObservation)

    def test_extra_fields_done_flag_set_correctly_on_last_ticket(self) -> None:
        """A penalty on the last ticket sets done=True and leaves an episode-level reward in (0, 1)."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            tickets_by_id = env._tickets_by_id  # noqa: SLF001 - test-only inspection
            # Answer every ticket but the last using only the allowed field (its true issue_type).
            for _ in range(obs.queue_size - 1):
                ticket = tickets_by_id[obs.current_ticket["ticket_id"]]
                obs = env.step(HelpdeskTicketAction(issue_type=ticket.issue_type))
            # Trigger the extra-field penalty on the final ticket.
            last_ticket = tickets_by_id[obs.current_ticket["ticket_id"]]
            final_obs = env.step(
                HelpdeskTicketAction(
                    issue_type=last_ticket.issue_type,
                    assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
                )
            )
            self.assertTrue(final_obs.done)
            self._assert_in_unit_interval(final_obs.reward, exclude_zero=True)
            self._assert_in_unit_interval(env.state.total_reward, exclude_zero=True)
if __name__ == "__main__":
    # Allow running this file directly as a script, without an external test runner.
    unittest.main(module=__name__)
|