# File size: 6,914 Bytes
# 2414d31
"""Integration tests for ClarifyEnvironment — full episode flows.

Requires full openenv stack installed. Run with:
    python -m pytest tests/test_environment.py -v
Skip with:
    python -m pytest -m 'not integration'
"""
from __future__ import annotations

import json

import pytest

pytestmark = pytest.mark.integration

from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction

from server.clarify_environment import ClarifyEnvironment
from server.scenarios import REQUIRED_KEYS_BY_FAMILY


@pytest.fixture
def env():
    """Provide a ClarifyEnvironment that has been reset with seed 7 on the "medium" task."""
    environment = ClarifyEnvironment()
    environment.reset(seed=7, task_id="medium")
    return environment


def _call(env, tool, **kwargs):
    """Step *env* once with a ``CallToolAction`` for *tool*.

    The keyword arguments are forwarded as the action's ``arguments`` mapping,
    and whatever ``env.step`` returns is handed back to the caller.
    """
    return env.step(CallToolAction(tool_name=tool, arguments=kwargs))


class TestReset:
    """Checks on the observation and state produced by ``reset``."""

    def test_returns_observation(self):
        """A reset is expected to yield a non-terminal, zero-reward "task" observation."""
        initial = ClarifyEnvironment().reset(seed=42, task_id="easy")
        assert initial.done is False
        assert initial.reward == 0.0
        assert initial.result["type"] == "task"

    def test_state_after_reset(self):
        """State right after a "medium" reset: no steps, six questions, nothing finished."""
        environment = ClarifyEnvironment()
        environment.reset(seed=42, task_id="medium")
        # Read the state once and check every field on that same snapshot.
        snapshot = environment.state
        assert snapshot.step_count == 0
        assert snapshot.questions_remaining == 6
        assert snapshot.episode_done is False
        assert snapshot.plan_submitted is False

    @pytest.mark.parametrize("diff", ["easy", "medium", "hard"])
    def test_all_difficulties(self, diff):
        """Each difficulty passed as ``task_id`` is echoed back in the observation."""
        observation = ClarifyEnvironment().reset(seed=1, task_id=diff)
        assert observation.result["task_id"] == diff


class TestListTools:
    """Checks on the tool listing returned for a ``ListToolsAction``."""

    def test_returns_three_tools(self, env):
        """The advertised tool names must be exactly the three expected ones."""
        listing = env.step(ListToolsAction())
        advertised = {tool.name for tool in listing.tools}
        assert advertised == {"get_task_info", "ask_question", "propose_plan"}


class TestGetTaskInfo:
    """Checks on the ``get_task_info`` tool."""

    def test_free_action(self, env):
        """The call earns no reward, does not end the episode, and returns task data."""
        response = _call(env, "get_task_info")
        assert response.reward == 0.0
        assert response.done is False
        payload = response.result.data
        # The payload must expose both the request text and the scenario family.
        assert "request" in payload
        assert "family" in payload

    def test_step_count_increments(self, env):
        """A single call is expected to move the step counter from 0 to 1."""
        _call(env, "get_task_info")
        assert env.state.step_count == 1


class TestAskQuestion:
    """Checks on the ``ask_question`` tool: reveals, budget, duplicates and the cap."""

    def test_reveals_field(self, env):
        """Asking for the order id should reveal a field and pay a positive reward."""
        response = _call(env, "ask_question", question="what is the order id?")
        assert response.result.data["field_revealed"] is not None
        assert response.reward > 0

    def test_budget_decrements(self, env):
        """One question should leave five of the six questions remaining."""
        _call(env, "ask_question", question="what is the order id?")
        assert env.state.questions_remaining == 5

    def test_duplicate_penalty(self, env):
        """A second order-id question should be flagged duplicate with negative reward."""
        _call(env, "ask_question", question="what is the order id?")
        repeat = _call(env, "ask_question", question="tell me order id again?")
        assert repeat.reward < 0
        assert repeat.result.data["duplicate"] is True

    def test_unknown_question_small_reward(self, env):
        """An off-topic question should reveal nothing yet still earn a positive reward."""
        response = _call(env, "ask_question", question="tell me about your cat")
        assert response.reward > 0
        assert response.result.data["field_revealed"] is None

    def test_over_cap(self, env):
        """After six questions, a seventh should report ``over_cap`` and end the episode."""
        for i in range(6):
            _call(env, "ask_question", question=f"question {i}")
        overflow = _call(env, "ask_question", question="one more")
        assert overflow.result.data["over_cap"] is True
        assert overflow.done is True

    def test_truncates_long_question(self, env):
        """A 500-character question must not end the episode."""
        oversized = "x" * 500
        response = _call(env, "ask_question", question=oversized)
        assert response.done is False


class TestProposePlan:
    """Checks on the ``propose_plan`` tool: termination and scoring."""

    # JSON plan carrying the order_id, item_issue and refund_or_replace keys.
    _COMPLETE_PLAN = json.dumps(
        {"order_id": "#1", "item_issue": "late", "refund_or_replace": "refund"}
    )

    def test_terminates_episode(self, env):
        """Submitting a plan should end the episode and record a final score."""
        response = _call(env, "propose_plan", plan=self._COMPLETE_PLAN)
        assert response.done is True
        assert env.state.episode_done is True
        assert env.state.plan_submitted is True
        assert env.state.final_score is not None

    def test_bad_json_zero_score(self, env):
        """A plan that is not JSON should end the episode with a score of 0.0."""
        response = _call(env, "propose_plan", plan="not json")
        assert response.done is True
        assert env.state.final_score == 0.0

    def test_missing_keys_zero_score(self, env):
        """A plan holding only ``order_id`` should end the episode with a score of 0.0."""
        response = _call(env, "propose_plan", plan='{"order_id": "#1"}')
        assert response.done is True
        assert env.state.final_score == 0.0

    def test_breakdown_populated(self, env):
        """The score breakdown should contain a ``FormatCheckRubric`` entry after a plan."""
        _call(env, "propose_plan", plan=self._COMPLETE_PLAN)
        assert "FormatCheckRubric" in env.state.score_breakdown


class TestEpisodeDoneGuard:
    """Checks that tool calls made after a submitted plan keep reporting ``done``."""

    @staticmethod
    def _submit_plan(env):
        """Submit a three-key plan through ``propose_plan`` and return its JSON text."""
        plan = json.dumps(
            {"order_id": "#1", "item_issue": "late", "refund_or_replace": "refund"}
        )
        _call(env, "propose_plan", plan=plan)
        return plan

    def test_no_ask_after_plan(self, env):
        """A question after the plan should return the "episode already ended" error."""
        self._submit_plan(env)
        late = _call(env, "ask_question", question="more info?")
        assert late.done is True
        assert late.result.data.get("error") == "episode already ended"

    def test_no_plan_after_plan(self, env):
        """Resubmitting the same plan should return the "episode already ended" error."""
        plan = self._submit_plan(env)
        late = _call(env, "propose_plan", plan=plan)
        assert late.done is True
        assert late.result.data.get("error") == "episode already ended"

    def test_no_info_after_plan(self, env):
        """``get_task_info`` after the plan should still report the episode as done."""
        self._submit_plan(env)
        late = _call(env, "get_task_info")
        assert late.done is True


class TestMaxSteps:
    """Checks that an episode ends on its own once enough steps have been taken."""

    def test_max_steps_enforced(self):
        """Repeated ``get_task_info`` calls must end an "easy" episode by step 8.

        Up to ten calls are made; the test fails if none of them reports ``done``.
        """
        environment = ClarifyEnvironment()
        environment.reset(seed=7, task_id="easy")
        for _ in range(10):
            if _call(environment, "get_task_info").done:
                assert environment.state.step_count <= 8
                break
        else:
            # The loop ran out without ever seeing a terminal observation.
            pytest.fail("max_steps not enforced within 10 calls")


class TestStepCount:
    """Checks on the environment's step counter."""

    def test_increments_each_call(self, env):
        """The counter should read 1, 2, 3 after three successive ``get_task_info`` calls."""
        for expected_steps in (1, 2, 3):
            _call(env, "get_task_info")
            assert env.state.step_count == expected_steps


class TestOraclePolicy:
    """Sanity check that a policy with access to the hidden profile scores well."""

    # One test case per seed, so a failing seed is reported on its own and
    # cannot stop the remaining seeds from running.
    @pytest.mark.parametrize("seed", range(10))
    def test_oracle_scores_high(self, seed):
        """Ask about every critical field, submit the hidden profile, expect score > 0.5.

        Args:
            seed: Seed passed to ``reset`` for the "medium" task.
        """
        e = ClarifyEnvironment()
        obs = e.reset(seed=seed, task_id="medium")
        family = obs.result["family"]
        # The oracle deliberately reads the environment's private scenario to
        # get the ground-truth answers and the list of fields to ask about.
        profile = e._scenario["hidden_profile"]
        critical = e._scenario["critical_fields"]
        # Every family produced by reset must have a required-keys entry.
        assert family in REQUIRED_KEYS_BY_FAMILY, f"seed={seed} unknown family={family}"

        for cf in critical:
            # Turn a field name such as "order_id" into the phrase "order id".
            kw = cf.replace("_", " ")
            _call(e, "ask_question", question=f"what is the {kw}?")

        plan = json.dumps(profile)
        obs = _call(e, "propose_plan", plan=plan)
        assert obs.done is True
        score = e.state.final_score
        assert score is not None
        assert score > 0.5, f"seed={seed} family={family} oracle score={score} too low"