File size: 8,195 Bytes
c492c3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
Integration tests for the OpenENV moderation environment API.

Uses FastAPI's TestClient (httpx-backed) — no server needed.
"""
import sys
import os

# Ensure the project root is on the path so imports resolve correctly
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import pytest
from fastapi.testclient import TestClient

from api.app import app

client = TestClient(app)


# ---------------------------------------------------------------------------
# Health + meta
# ---------------------------------------------------------------------------

def test_health():
    r = client.get("/health")
    assert r.status_code == 200
    assert r.json().get("status") == "ok"


def test_tasks_list():
    r = client.get("/tasks")
    assert r.status_code == 200
    tasks = r.json()
    assert "easy_harassment" in tasks
    assert "medium_ambiguous" in tasks
    assert "hard_misinformation" in tasks


# ---------------------------------------------------------------------------
# Error cases before reset
# ---------------------------------------------------------------------------

def test_step_without_reset():
    # Use a fresh app instance to guarantee no prior episode
    from api.app import _state_manager
    _state_manager._state = None

    r = client.post("/step", json={"action_type": "allow", "parameters": {}})
    assert r.status_code == 400


def test_state_without_reset():
    from api.app import _state_manager
    _state_manager._state = None

    r = client.get("/state")
    assert r.status_code == 400


def test_grader_without_reset():
    from api.app import _state_manager
    _state_manager._state = None

    r = client.get("/grader")
    assert r.status_code == 400


def test_reset_unknown_task():
    r = client.post("/reset", json={"task_id": "nonexistent_task"})
    assert r.status_code == 400


def test_invalid_action_type():
    client.post("/reset", json={"task_id": "easy_harassment", "seed": 42})
    r = client.post("/step", json={"action_type": "explode", "parameters": {}})
    assert r.status_code == 422  # Pydantic validation error


# ---------------------------------------------------------------------------
# Easy episode — happy path
# ---------------------------------------------------------------------------

def test_easy_full_episode():
    r = client.post("/reset", json={"task_id": "easy_harassment", "seed": 42})
    assert r.status_code == 200
    obs = r.json()
    assert obs["done"] is False
    assert obs["content"] != ""
    assert obs["user_history"] is None  # not yet revealed

    # Investigate
    r = client.post("/step", json={"action_type": "fetch_user_history", "parameters": {}})
    assert r.status_code == 200
    result = r.json()
    assert result["observation"]["user_history"] is not None
    assert result["reward"] > 0

    # Classify
    r = client.post("/step", json={
        "action_type": "mark_violation_type",
        "parameters": {"violation_type": "harassment"},
    })
    assert r.status_code == 200

    # Terminal action
    r = client.post("/step", json={"action_type": "remove", "parameters": {}})
    assert r.status_code == 200
    result = r.json()
    assert result["done"] is True

    # Grade
    r = client.get("/grader")
    assert r.status_code == 200
    score = r.json()
    assert 0.0 <= score["total"] <= 1.0
    assert score["final_action_score"] == 1.0  # remove is correct for easy harassment


# ---------------------------------------------------------------------------
# Medium episode — happy path
# ---------------------------------------------------------------------------

def test_medium_full_episode():
    r = client.post("/reset", json={"task_id": "medium_ambiguous", "seed": 100})
    assert r.status_code == 200

    # Investigate: user history + policy clause
    client.post("/step", json={"action_type": "fetch_user_history", "parameters": {}})
    client.post("/step", json={"action_type": "check_policy_clause", "parameters": {}})

    # Classify
    client.post("/step", json={
        "action_type": "mark_violation_type",
        "parameters": {"violation_type": "safe"},
    })

    # Terminal
    r = client.post("/step", json={"action_type": "allow", "parameters": {}})
    assert r.status_code == 200
    assert r.json()["done"] is True

    # Grade
    r = client.get("/grader")
    score = r.json()
    assert 0.0 <= score["total"] <= 1.0


# ---------------------------------------------------------------------------
# Hard episode — happy path
# ---------------------------------------------------------------------------

def test_hard_full_episode():
    r = client.post("/reset", json={"task_id": "hard_misinformation", "seed": 10})
    assert r.status_code == 200

    for action in [
        "fetch_user_history",
        "fetch_thread_context",
        "check_policy_clause",
    ]:
        r = client.post("/step", json={"action_type": action, "parameters": {}})
        assert r.status_code == 200

    client.post("/step", json={
        "action_type": "mark_violation_type",
        "parameters": {"violation_type": "misinformation"},
    })

    r = client.post("/step", json={"action_type": "remove", "parameters": {}})
    assert r.json()["done"] is True

    r = client.get("/grader")
    score = r.json()
    assert score["total"] > 0.5   # well-investigated + correct action should score high


# ---------------------------------------------------------------------------
# Episode termination guard
# ---------------------------------------------------------------------------

def test_action_after_episode_done():
    client.post("/reset", json={"task_id": "easy_harassment", "seed": 42})
    client.post("/step", json={"action_type": "remove", "parameters": {}})

    # Try another action after episode is done
    r = client.post("/step", json={"action_type": "allow", "parameters": {}})
    assert r.status_code == 400


# ---------------------------------------------------------------------------
# Grader before episode done
# ---------------------------------------------------------------------------

def test_grader_before_done():
    client.post("/reset", json={"task_id": "easy_harassment", "seed": 42})
    r = client.get("/grader")
    assert r.status_code == 400


# ---------------------------------------------------------------------------
# /state reflects progressive reveal
# ---------------------------------------------------------------------------

def test_state_progressive_reveal():
    client.post("/reset", json={"task_id": "hard_misinformation", "seed": 777})

    r = client.get("/state")
    assert r.json()["user_history"] is None

    client.post("/step", json={"action_type": "fetch_user_history", "parameters": {}})

    r = client.get("/state")
    assert r.json()["user_history"] is not None
    assert r.json()["thread_context"] is None  # not yet revealed


# ---------------------------------------------------------------------------
# Baseline agent
# ---------------------------------------------------------------------------

def test_baseline_easy():
    r = client.get("/baseline", params={"task_id": "easy_harassment", "seed": 42})
    assert r.status_code == 200
    result = r.json()
    assert 0.0 <= result["score"]["total"] <= 1.0
    assert len(result["trajectory"]) > 0


def test_baseline_medium():
    r = client.get("/baseline", params={"task_id": "medium_ambiguous", "seed": 100})
    assert r.status_code == 200
    assert 0.0 <= r.json()["score"]["total"] <= 1.0


def test_baseline_hard():
    r = client.get("/baseline", params={"task_id": "hard_misinformation", "seed": 777})
    assert r.status_code == 200
    assert 0.0 <= r.json()["score"]["total"] <= 1.0


# ---------------------------------------------------------------------------
# Determinism — same seed → same result
# ---------------------------------------------------------------------------

def test_determinism():
    r1 = client.get("/baseline", params={"task_id": "easy_harassment", "seed": 42})
    r2 = client.get("/baseline", params={"task_id": "easy_harassment", "seed": 42})
    assert r1.json()["score"]["total"] == r2.json()["score"]["total"]