File size: 4,074 Bytes
8ada670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import annotations

import os
import sys
import unittest

# Put this file's parent directory first on sys.path so the project modules
# imported further down (``server.app``) resolve regardless of the current
# working directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

# Absolute form of that same parent directory.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# NOTE(review): ".venv/Lib/site-packages" is the virtual-environment layout
# used on Windows; POSIX venvs use "lib/pythonX.Y/site-packages", so this
# entry presumably points at nothing there -- confirm that is intended.
SITE_PACKAGES = os.path.join(REPO_ROOT, ".venv", "Lib", "site-packages")
if SITE_PACKAGES not in sys.path:
    sys.path.insert(0, SITE_PACKAGES)

# Forget any already-imported ``openenv`` package (and its submodules) as well
# as the project modules listed below, so the imports in the ``try`` block are
# resolved afresh against the sys.path prepared above.
_PROJECT_MODULES_TO_PURGE = ("models", "server.app", "server.environment", "client")
for module_name in tuple(sys.modules):
    # The first dotted component is "openenv" exactly for the package itself
    # and for every "openenv.<something>" submodule.
    top_level_package = module_name.partition(".")[0]
    if top_level_package == "openenv" or module_name in _PROJECT_MODULES_TO_PURGE:
        del sys.modules[module_name]

# Import the real stack if possible; otherwise remember why it failed so the
# test class can be skipped with an informative reason.
try:
    from starlette.testclient import TestClient
    from server.app import app

    REAL_OPENENV_AVAILABLE = True
    IMPORT_ERROR: Exception | None = None
except Exception as exc:  # pragma: no cover - only used for skip messaging
    IMPORT_ERROR = exc
    REAL_OPENENV_AVAILABLE = False


@unittest.skipUnless(
    REAL_OPENENV_AVAILABLE,
    f"real OpenEnv stack unavailable: {IMPORT_ERROR}",
)
class RealOpenEnvIntegrationTests(unittest.TestCase):
    """Exercise the HTTP and WebSocket endpoints of ``server.app`` in-process.

    The whole class is skipped when importing the test client or the
    application failed at module import time; the skip reason carries the
    captured import error.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Create one ``TestClient`` shared by every test in the class."""
        cls.client = TestClient(app)

    @classmethod
    def tearDownClass(cls) -> None:
        """Close the shared client so it does not outlive the test class."""
        cls.client.close()

    def test_root_redirects_to_web(self) -> None:
        # Redirects are not followed so the 307 response itself is inspected.
        response = self.client.get("/", follow_redirects=False)
        self.assertEqual(response.status_code, 307)
        self.assertEqual(response.headers["location"], "/web")

    def test_grader_endpoint_scores_known_action(self) -> None:
        # Submit an action for ticket-002 that is expected to be graded as a
        # perfect match.
        response = self.client.post(
            "/grader",
            json={
                "task_id": 3,
                "ticket_id": "ticket-002",
                "action": {
                    "issue_type": "identity_access",
                    "priority": "high",
                    "assignment_group": "service_desk",
                    "resolution_action": "fulfill",
                },
            },
        )
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload["score"], 1.0)
        self.assertEqual(payload["breakdown"]["issue_type"], 1.0)

    def test_baseline_endpoint_runs_episode(self) -> None:
        # The baseline run must echo the task id and report at least one step.
        response = self.client.get("/baseline", params={"task_id": 3, "seed": 42})
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload["task_id"], 3)
        self.assertGreater(payload["step_count"], 0)
        self.assertIn("steps", payload)
        self.assertIsInstance(payload["steps"], list)

    def test_websocket_round_trip_reset_state_step(self) -> None:
        # One WebSocket session: reset -> state -> step.
        with self.client.websocket_connect("/ws") as websocket:
            websocket.send_json({"type": "reset", "data": {"task_id": 1, "seed": 42}})
            reset_message = websocket.receive_json()
            self.assertEqual(reset_message["type"], "observation")
            reset_payload = reset_message["data"]
            # Accept both a nested {"observation": {...}} payload and a flat
            # payload where the observation fields sit directly under "data".
            reset_obs = reset_payload.get("observation", reset_payload)
            self.assertEqual(reset_obs["task_id"], 1)
            self.assertFalse(reset_payload.get("done", reset_obs.get("done", False)))

            websocket.send_json({"type": "state"})
            state_message = websocket.receive_json()
            self.assertEqual(state_message["type"], "state")
            self.assertEqual(state_message["data"]["current_task_id"], 1)

            websocket.send_json(
                {
                    "type": "step",
                    "data": {
                        "issue_type": "billing_license",
                    },
                }
            )
            step_message = websocket.receive_json()
            self.assertEqual(step_message["type"], "observation")
            step_payload = step_message["data"]
            step_obs = step_payload.get("observation", step_payload)
            reward = step_payload.get("reward", step_obs.get("reward"))
            # Fail with a readable message instead of a TypeError from the
            # range comparisons below when no reward was returned.
            self.assertIsNotNone(reward, "step response did not include a reward")
            self.assertGreaterEqual(reward, 0.0)
            self.assertLessEqual(reward, 1.0)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()