File size: 6,703 Bytes
a3d65ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
Tests for SupportTicketEnvironment β€” runs the environment directly
(no HTTP server required).
"""

import pytest
from support_ticket_env.server.support_environment import SupportTicketEnvironment
from support_ticket_env.models import SupportAction


# ─────────────────────────── fixtures ──────────────────────────

@pytest.fixture
def env():
    """Provide a newly constructed SupportTicketEnvironment to the requesting test."""
    environment = SupportTicketEnvironment()
    return environment


# ─────────────────────────── Task 1 ────────────────────────────

class TestTask1:
    """Episodes started with ``task_id=1`` and a fixed seed."""

    def test_reset_returns_observation(self, env):
        """reset() hands back an observation with ticket text, task id 1, not done."""
        first_obs = env.reset(task_id=1, seed=42)

        assert first_obs.ticket_text
        assert first_obs.task_id == 1
        assert first_obs.done is False

    def test_correct_classification(self, env):
        """Submitting the category reported by env.state gives reward 1.0 and done."""
        env.reset(task_id=1, seed=42)
        # The expected answer is exposed on the environment state.
        expected_category = env.state.correct_category

        result = env.step(
            SupportAction(action_type="classify", category=expected_category)
        )

        assert result.reward == 1.0
        assert result.done is True

    def test_wrong_classification(self, env):
        """Submitting any other category gives reward 0.0 and still ends the episode."""
        env.reset(task_id=1, seed=42)
        expected_category = env.state.correct_category

        # First candidate category that differs from the expected one.
        candidates = ("billing", "technical", "account", "general", "refund")
        wrong_category = next(c for c in candidates if c != expected_category)

        result = env.step(
            SupportAction(action_type="classify", category=wrong_category)
        )

        assert result.reward == 0.0
        assert result.done is True

    def test_non_classify_action_penalised(self, env):
        """A reply action in task 1 must be handled and still produce a reward value."""
        env.reset(task_id=1, seed=42)
        reply = SupportAction(action_type="reply", reply_text="hello")

        result = env.step(reply)

        # Should not crash; done might be False and reward 0
        assert result.reward is not None


# ─────────────────────────── Task 2 ────────────────────────────

class TestTask2:
    """Episodes started with ``task_id=2``: classify first, then act."""

    def test_full_correct_episode(self, env):
        """Correct classify keeps the episode open; the correct action then closes it."""
        env.reset(task_id=2, seed=42)
        # Snapshot taken once, before any step, and reused for both answers.
        snapshot = env.state

        # Step 1: classify
        classify = SupportAction(
            action_type="classify",
            category=snapshot.correct_category,
        )
        after_classify = env.step(classify)
        assert after_classify.done is False
        assert after_classify.reward > 0

        # Step 2: correct action
        follow_up = SupportAction(action_type=snapshot.correct_action)
        after_action = env.step(follow_up)
        assert after_action.done is True
        assert after_action.reward >= 0.5

    def test_must_classify_first(self, env):
        """Escalating before classifying leaves the episode open and mentions 'classify'."""
        env.reset(task_id=2, seed=7)

        result = env.step(SupportAction(action_type="escalate"))

        assert result.done is False
        assert "classify" in result.feedback.lower()

    def test_state_reflects_progress(self, env):
        """env.state flips ``classified`` and counts one step after a classify action."""
        env.reset(task_id=2, seed=7)
        before = env.state
        assert before.classified is False

        classify = SupportAction(
            action_type="classify",
            category=before.correct_category,
        )
        env.step(classify)

        after = env.state
        assert after.classified is True
        assert after.step_count == 1


# ─────────────────────────── Task 3 ────────────────────────────

class TestTask3:
    """Episodes started with ``task_id=3``: a queue of tickets handled in turn.

    Both multi-step tests drive the environment with the same "oracle" policy
    (read the expected answer from ``env.state`` and submit it). The policy
    lives in private helpers so the two tests cannot drift apart, e.g. one
    sending a reply body and the other sending a bare reply.
    """

    @staticmethod
    def _oracle_action(state):
        """Build the action that ``state`` itself reports as correct.

        Args:
            state: Object returned by ``env.state``; its ``classified``,
                ``correct_category`` and ``correct_action`` attributes are read.

        Returns:
            A ``SupportAction``: a classify action while ``state.classified``
            is falsy, otherwise the action named by ``state.correct_action``.
        """
        if not state.classified:
            return SupportAction(
                action_type="classify",
                category=state.correct_category,
            )
        if state.correct_action == "reply":
            # Always attach a body to replies so every test submits the same
            # reply action rather than one with no text.
            return SupportAction(
                action_type="reply",
                reply_text=f"We are handling your {state.correct_category} issue.",
            )
        return SupportAction(action_type=state.correct_action)

    def _run_oracle_episode(self, env, max_steps):
        """Step ``env`` with the oracle policy until done or the step cap.

        Args:
            env: An environment that has already been reset.
            max_steps: Upper bound on the number of ``env.step`` calls, so a
                misbehaving environment cannot hang the test.

        Returns:
            ``(done, total_reward)`` where ``done`` is whether an observation
            with ``done`` set was seen and ``total_reward`` sums the per-step
            rewards (a ``None`` reward counts as 0.0).
        """
        total_reward = 0.0
        for _ in range(max_steps):
            obs = env.step(self._oracle_action(env.state))
            total_reward += obs.reward or 0.0
            if obs.done:
                return True, total_reward
        return False, total_reward

    def test_queue_has_three_tickets(self, env):
        """After reset the state reports three tickets in total and none resolved."""
        env.reset(task_id=3, seed=42)
        state = env.state
        assert state.tickets_total == 3
        assert state.tickets_resolved == 0

    def test_resolve_all_tickets(self, env):
        """The oracle policy finishes the episode with all three tickets resolved."""
        env.reset(task_id=3, seed=42)

        done, _ = self._run_oracle_episode(env, max_steps=30)

        assert done, "Episode should finish after 3 tickets"
        assert env.state.tickets_resolved == 3

    def test_total_reward_positive(self, env):
        """The oracle policy accumulates a strictly positive total reward."""
        env.reset(task_id=3, seed=123)

        _, total = self._run_oracle_episode(env, max_steps=20)

        assert total > 0.0


# ─────────────────────────── State API ─────────────────────────

class TestStateAPI:
    """Checks on the values exposed through ``env.state``."""

    def test_state_after_reset(self, env):
        """Right after reset: zero steps, the requested task id, a non-empty ticket id."""
        env.reset(task_id=1, seed=0)
        snapshot = env.state

        assert snapshot.step_count == 0
        assert snapshot.task_id == 1
        assert snapshot.ticket_id != ""

    def test_step_count_increments(self, env):
        """One step() call moves ``step_count`` to 1."""
        env.reset(task_id=1, seed=0)
        expected_category = env.state.correct_category
        classify = SupportAction(action_type="classify", category=expected_category)

        env.step(classify)

        assert env.state.step_count == 1


# ─────────────────────────── Reward bounds ─────────────────────

class TestRewardBounds:
    """Sanity check on the numeric range of step rewards."""

    def test_reward_in_range(self, env):
        """First-step reward stays within [-1.0, 1.0] for every seed/task pair tried."""
        seeds = [0, 1, 2, 3, 42]
        task_ids = [1, 2, 3]
        # Seed-major order; the same env instance is reset for each scenario.
        scenarios = [(seed, task_id) for seed in seeds for task_id in task_ids]

        for seed, task_id in scenarios:
            env.reset(task_id=task_id, seed=seed)
            expected_category = env.state.correct_category
            obs = env.step(
                SupportAction(action_type="classify", category=expected_category)
            )

            # A missing (None) reward is treated as 0.0 for the bounds check.
            reward = obs.reward or 0.0
            assert -1.0 <= reward <= 1.0, f"Reward out of bounds: {obs.reward}"