File size: 9,002 Bytes
42dd095
 
 
 
 
 
 
043d9e1
42dd095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
043d9e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42dd095
 
 
8241eb5
 
42dd095
043d9e1
 
42dd095
043d9e1
 
42dd095
043d9e1
 
 
 
 
 
 
42dd095
 
8241eb5
d6d9493
42dd095
 
 
 
043d9e1
 
 
42dd095
043d9e1
 
 
 
 
42dd095
 
 
8241eb5
 
42dd095
043d9e1
 
42dd095
043d9e1
 
 
 
 
42dd095
 
 
8241eb5
c0d489c
42dd095
 
 
 
043d9e1
 
42dd095
043d9e1
 
 
 
 
42dd095
 
 
 
 
8241eb5
c0d489c
42dd095
 
 
 
043d9e1
 
42dd095
043d9e1
 
 
 
 
 
 
42dd095
043d9e1
 
42dd095
 
 
 
 
 
 
 
8ccf96d
 
 
043d9e1
 
 
 
 
 
 
 
 
 
8ccf96d
 
 
 
 
 
42dd095
 
 
043d9e1
 
 
 
 
 
 
 
 
 
 
 
 
42dd095
 
 
 
c64d203
42dd095
043d9e1
 
 
 
 
 
 
 
 
 
 
 
c64d203
 
043d9e1
 
 
 
 
42dd095
 
c0d489c
 
 
 
42dd095
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""
Tests for action field validation (Task 4) in HelpdeskTicketRoutingEnvironment.step().

Validates Requirement 7: Step Validates Action Fields Against Task Contract.
"""
from __future__ import annotations

import contextlib
import sys
import os
import unittest
import types as _types

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import openenv_test_stubs  # noqa: F401

if "openenv.core.env_server.interfaces" not in sys.modules:

    class _Environment:
        """Bare-bones stand-in for openenv's ``Environment`` base class.

        Only registered when nothing (e.g. ``openenv_test_stubs``) has already
        put ``openenv.core.env_server.interfaces`` into ``sys.modules``.
        """

        def __init__(self) -> None:
            # No state to set up; subclasses initialise their own attributes.
            return None

        def __init_subclass__(cls, **kwargs: object) -> None:
            super().__init_subclass__(**kwargs)

        @classmethod
        def __class_getitem__(cls, item: object) -> type:
            # Subscripting the class (generic-style) is accepted and ignored:
            # the class itself is returned unchanged.
            return cls

    _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")
    setattr(_interfaces_mod, "Environment", _Environment)
    sys.modules[_interfaces_mod.__name__] = _interfaces_mod

from models import HelpdeskTicketAction, HelpdeskTicketObservation
from server.environment import HelpdeskTicketRoutingEnvironment
from server.tasks import TASKS
from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS


def _make_env() -> HelpdeskTicketRoutingEnvironment:
    """Construct a brand-new environment instance for a single test."""
    environment = HelpdeskTicketRoutingEnvironment()
    return environment


def _task_with_issue_type_only(task_id: int) -> dict:
    """Return a shallow copy of ``TASKS[task_id]``.

    For task 1 the copy's ``allowed_fields`` is replaced with ``["issue_type"]``.
    Every other task is copied unchanged. ``TASKS`` itself is never mutated:
    only a key of the copy is rebound.
    """
    task_copy = dict(TASKS[task_id])
    if task_id != 1:
        return task_copy
    task_copy["allowed_fields"] = ["issue_type"]
    return task_copy


@contextlib.contextmanager
def _restrict_task_1_fields():
    """Temporarily narrow task 1's ``allowed_fields`` to ``["issue_type"]``.

    ``TASKS`` is module-level state imported from ``server.tasks``, so the
    original value is put back on exit even if the ``with`` body raises.
    The exact original list object is restored, not a copy, so ``TASKS``
    is left identical to how it was found.
    """
    original_fields = TASKS[1]["allowed_fields"]
    # Build a fresh list on every entry so nothing done inside the body can
    # leak into a later use of this context manager.
    TASKS[1]["allowed_fields"] = ["issue_type"]
    try:
        yield
    finally:
        TASKS[1]["allowed_fields"] = original_fields


class TestExtraFieldsPenalty(unittest.TestCase):
    """Requirement 7: step() rejects actions with fields outside the task's allowed_fields.

    Every test narrows task 1's ``allowed_fields`` to ``["issue_type"]`` via
    ``_restrict_task_1_fields``, so any other action field counts as extra.
    """

    def _assert_in_unit_interval(self, value: object, *, allow_zero: bool = True) -> None:
        """Assert that *value* is not None and lies in [0.0, 1.0).

        With ``allow_zero=False`` the lower bound is exclusive too: (0.0, 1.0).
        Checking for None first turns a missing reward/score into a clean test
        failure rather than a TypeError raised by the ordering comparison.
        """
        self.assertIsNotNone(value)
        if allow_zero:
            self.assertGreaterEqual(value, 0.0)
        else:
            self.assertGreater(value, 0.0)
        self.assertLess(value, 1.0)

    def test_extra_fields_returns_closed_interval_penalty_reward(self) -> None:
        """A task 1 penalty step must return a reward in [0.0, 1.0)."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)

            # With the restriction in place, assignment_group is not allowed for task 1.
            self.assertNotIn("assignment_group", obs.allowed_fields)

            # Submit an action with an extra field (assignment_group) not in task 1's allowed_fields
            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                priority=PRIORITIES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
            )
            penalty_obs = env.step(action)

        self.assertIsInstance(penalty_obs, HelpdeskTicketObservation)
        self._assert_in_unit_interval(penalty_obs.reward)

    def test_extra_fields_advances_ticket_index(self) -> None:
        """A penalty step must advance tickets_processed by exactly 1."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            self.assertEqual(obs.tickets_processed, 0)

            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field for task 1
            )
            penalty_obs = env.step(action)

        self.assertEqual(penalty_obs.tickets_processed, 1)

    def test_extra_fields_records_score_inside_unit_interval(self) -> None:
        """After a penalty step, state.per_ticket_scores holds one score in [0.0, 1.0)."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)

            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
            )
            env.step(action)

        state = env.state
        self.assertEqual(len(state.per_ticket_scores), 1)
        self._assert_in_unit_interval(state.per_ticket_scores[0])

    def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
        """The history entry for a penalty step names the offending field in penalty_reason."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)

            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
            )
            penalty_obs = env.step(action)

        self.assertEqual(len(penalty_obs.history), 1)
        entry = penalty_obs.history[0]
        self.assertIn("penalty_reason", entry)
        self.assertIn("assignment_group", entry["penalty_reason"])
        self._assert_in_unit_interval(entry["score"])

    def test_no_extra_fields_grades_normally(self) -> None:
        """Actions limited to allowed_fields are graded normally: no penalty_reason is recorded."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)

            # Build action using only allowed fields
            allowed = obs.allowed_fields
            action_kwargs = {}
            if "issue_type" in allowed:
                action_kwargs["issue_type"] = ISSUE_TYPES[0]
            if "priority" in allowed:
                action_kwargs["priority"] = PRIORITIES[0]

            action = HelpdeskTicketAction(**action_kwargs)
            result_obs = env.step(action)

        # ISSUE_TYPES[0] is not matched to the current ticket, so only the range
        # is checked: the reward may be any value in [0.0, 1.0].
        self.assertIsInstance(result_obs, HelpdeskTicketObservation)
        self.assertIsNotNone(result_obs.reward)
        self.assertGreaterEqual(result_obs.reward, 0.0)
        self.assertLessEqual(result_obs.reward, 1.0)
        # No penalty_reason in history
        self.assertEqual(len(result_obs.history), 1)
        self.assertNotIn("penalty_reason", result_obs.history[0])

    def test_action_metadata_is_not_treated_as_extra_field(self) -> None:
        """OpenEnv Action metadata should not trigger the extra-fields penalty."""
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            ticket_id = obs.current_ticket["ticket_id"]
            current_ticket = env._tickets_by_id[ticket_id]  # noqa: SLF001 - test-only inspection

            # Correct issue_type plus (empty) metadata: must be graded, not penalised.
            result_obs = env.step(
                HelpdeskTicketAction(
                    issue_type=current_ticket.issue_type,
                    metadata={},
                )
            )

        self.assertEqual(len(result_obs.history), 1)
        self.assertNotIn("penalty_reason", result_obs.history[0])
        self.assertGreater(result_obs.history[0]["score"], 0.0)

    def test_extra_fields_no_exception_raised(self) -> None:
        """Requirement 7.4: extra fields must not raise an unhandled exception."""
        env = _make_env()
        with _restrict_task_1_fields():
            env.reset(seed=42, task_id=1)

            action = HelpdeskTicketAction(
                issue_type=ISSUE_TYPES[0],
                priority=PRIORITIES[0],
                assignment_group=ASSIGNMENT_GROUPS[0],
                resolution_action=RESOLUTION_ACTIONS[0],  # multiple extra fields
            )
            try:
                obs = env.step(action)
            except Exception as exc:  # noqa: BLE001
                self.fail(f"step() raised an unexpected exception: {exc}")

        self.assertIsInstance(obs, HelpdeskTicketObservation)

    def test_extra_fields_done_flag_set_correctly_on_last_ticket(self) -> None:
        """A penalty on the final ticket still ends the episode (done is True).

        Both the step reward and state.total_reward must lie strictly inside (0.0, 1.0).
        """
        env = _make_env()
        with _restrict_task_1_fields():
            obs = env.reset(seed=42, task_id=1)
            queue_size = obs.queue_size
            tickets_by_id = env._tickets_by_id  # noqa: SLF001 - test-only inspection

            # Process all tickets except the last one normally
            for _ in range(queue_size - 1):
                current_ticket_id = obs.current_ticket["ticket_id"]
                current_ticket = tickets_by_id[current_ticket_id]
                obs = env.step(HelpdeskTicketAction(issue_type=current_ticket.issue_type))

            # Now trigger penalty on the last ticket
            current_ticket_id = obs.current_ticket["ticket_id"]
            current_ticket = tickets_by_id[current_ticket_id]
            action = HelpdeskTicketAction(
                issue_type=current_ticket.issue_type,
                assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
            )
            final_obs = env.step(action)

        self.assertTrue(final_obs.done)
        self._assert_in_unit_interval(final_obs.reward, allow_zero=False)
        self._assert_in_unit_interval(env.state.total_reward, allow_zero=False)


if __name__ == "__main__":
    # Support running this file directly; verbosity=1 is unittest.main's default.
    unittest.main(verbosity=1)