"""Round-trip test: run the FastAPI app under TestClient and drive it via OpenSOCClient.

The client is HTTP-only and must not import server internals. Rather than
patching `requests`, this test hands the client a `requests.Session`-shaped
adapter (`_TestClientSession`) that routes each call to the FastAPI
TestClient, so the client can be verified without opening a real socket.
"""

from __future__ import annotations

import os
import sys
from typing import Any, Dict

import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from app_runtime import _envs, app  # noqa: E402
from client import OpenSOCClient  # noqa: E402


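class _TestClientSession:
    """Adapter that gives a FastAPI TestClient the `requests.Session` shape.

    OpenSOCClient builds absolute URLs from its `base_url` (here `http://test`),
    while the TestClient expects app-relative paths, so the scheme and host
    are stripped before each call.
    """

    def __init__(self):
        from fastapi.testclient import TestClient
        self.tc = TestClient(app)

    @staticmethod
    def _path(url: str) -> str:
        # "http://test/reset" -> "/reset"; a bare origin such as "http://test" -> "/".
        rest = url.split("//", 1)[-1]
        return "/" + rest.split("/", 1)[1] if "/" in rest else "/"

    # `timeout` is accepted for signature compatibility with `requests` but is
    # ignored: TestClient requests are handled in-process.
    def get(self, url: str, params: Dict[str, Any] | None = None, timeout: float | None = None):
        return self.tc.get(self._path(url), params=params)

    def post(self, url: str, params: Dict[str, Any] | None = None, json: Any = None, timeout: float | None = None):
        return self.tc.post(self._path(url), params=params, json=json)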
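

# Sketch (hypothetical helper, not used by the adapter above): the same
# origin-stripping expressed with the standard library's urllib.parse.urlsplit.
# It maps "http://test/reset" to "/reset", a bare origin to "/", and keeps any
# query string embedded in the URL. The test below pins those cases.
def _strip_origin(url: str) -> str:
    from urllib.parse import urlsplit

    parts = urlsplit(url)
    path = parts.path or "/"  # "http://test" has an empty path component
    return f"{path}?{parts.query}" if parts.query else path


def test_strip_origin():
    assert _strip_origin("http://test/reset") == "/reset"
    assert _strip_origin("http://test") == "/"
    assert _strip_origin("http://test/grade?seed=3") == "/grade?seed=3"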


@pytest.fixture()
def client():
    # Clear the app's shared `_envs` state so nothing leaks between tests.
    _envs.clear()
    return OpenSOCClient(base_url="http://test", session=_TestClientSession())


class TestOpenSOCClient:
    def test_health(self, client):
        h = client.health()
        assert h["status"] == "ok"

    def test_tasks(self, client):
        t = client.tasks()
        assert len(t["tasks"]) == 4

    def test_round_trip(self, client):
        obs = client.reset(task="stage1_basic", mode="defender_only", seed=3)
        assert obs["role"] == "defender"
        first_log_id = obs["log_window"][0]["log_id"]
        result = client.step(
            {"submit_triage": {
                "action": "monitor",
                "cited_log_id": first_log_id,
                "rationale": "client test",
            }},
            task="stage1_basic", mode="defender_only", seed=3,
        )
        assert result["done"] is True
        grade = client.grade(task="stage1_basic", mode="defender_only", seed=3)
        assert 0.0 <= grade["score"] <= 1.0