File size: 6,653 Bytes
1794757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from __future__ import annotations

from fastapi.testclient import TestClient
from fastapi.middleware.cors import CORSMiddleware

from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.server import DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX, create_app
from trenches_env.session_manager import SessionManager
from trenches_env.source_ingestion import SourceHarvester


def build_manager() -> SessionManager:
    """Build a fresh SessionManager for a single test.

    The environment's SourceHarvester is constructed with ``auto_start=False``.
    NOTE(review): presumably this keeps background source harvesting from
    starting during tests; confirm against SourceHarvester.
    """
    harvester = SourceHarvester(auto_start=False)
    return SessionManager(env=FogOfWarDiplomacyEnv(source_harvester=harvester))


def _cors_kwargs(app) -> dict[str, object]:
    middleware = next(entry for entry in app.user_middleware if entry.cls is CORSMiddleware)
    return middleware.kwargs


def test_server_defaults_to_localhost_any_port_cors(monkeypatch) -> None:
    """With no TRENCHES_CORS_* overrides, CORS falls back to the default regex.

    No explicit origins are configured, the origin regex equals
    DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX, and credentials are allowed.
    """
    for variable in (
        "TRENCHES_CORS_ALLOW_ORIGINS",
        "TRENCHES_CORS_ALLOW_ORIGIN_REGEX",
        "TRENCHES_CORS_ALLOW_CREDENTIALS",
    ):
        monkeypatch.delenv(variable, raising=False)

    settings = _cors_kwargs(create_app(session_manager=build_manager()))

    assert settings["allow_origins"] == []
    assert settings["allow_origin_regex"] == DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX
    assert settings["allow_credentials"] is True


def test_server_honors_explicit_cors_origin_list(monkeypatch) -> None:
    """An explicit comma-separated origin list is split and whitespace-trimmed.

    When an origin list is given, no origin regex is installed, and an
    explicit ``false`` credentials setting is respected.
    """
    monkeypatch.setenv(
        "TRENCHES_CORS_ALLOW_ORIGINS",
        "https://dashboard.example.com, https://ops.example.com",
    )
    monkeypatch.delenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX", raising=False)
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "false")

    settings = _cors_kwargs(create_app(session_manager=build_manager()))

    expected_origins = ["https://dashboard.example.com", "https://ops.example.com"]
    assert settings["allow_origins"] == expected_origins
    # An explicit origin list replaces the regex fallback entirely.
    assert settings["allow_origin_regex"] is None
    assert settings["allow_credentials"] is False


def test_server_disables_credentials_for_wildcard_cors(monkeypatch) -> None:
    """A wildcard origin forces credentials off, even when they are requested.

    TRENCHES_CORS_ALLOW_CREDENTIALS is set to ``"true"``, yet the resulting
    middleware must report ``allow_credentials`` as False. Browsers reject
    credentialed responses whose allowed origin is a literal ``*``.
    """
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_ORIGINS", "*")
    monkeypatch.delenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX", raising=False)
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "true")

    settings = _cors_kwargs(create_app(session_manager=build_manager()))

    assert settings["allow_origins"] == ["*"]
    assert settings["allow_origin_regex"] is None
    assert settings["allow_credentials"] is False


def test_server_exposes_scenarios_and_benchmark_endpoints() -> None:
    """Check the scenario listing and a one-scenario benchmark run.

    ``/scenarios`` must list ``shipping_crisis``. ``/benchmarks/run`` must
    report one result for a benchmark over that single scenario.
    """
    client = TestClient(create_app(session_manager=build_manager()))

    listing = client.get("/scenarios")
    assert listing.status_code == 200
    assert any(entry["id"] == "shipping_crisis" for entry in listing.json())

    run_request = {
        "scenario_ids": ["shipping_crisis"],
        "seed": 21,
        "steps_per_scenario": 2,
    }
    run = client.post("/benchmarks/run", json=run_request)
    assert run.status_code == 200
    summary = run.json()
    assert summary["scenario_count"] == 1
    assert summary["results"][0]["scenario_id"] == "shipping_crisis"


def test_capabilities_expose_model_provider_bindings(monkeypatch) -> None:
    """US provider/model env vars appear in ``/capabilities`` as a configured binding.

    The binding must report the provider, the API-key env var, and a tool list
    that includes ``negotiate``.
    """
    monkeypatch.setenv("TRENCHES_MODEL_PROVIDER_US", "huggingface")
    monkeypatch.setenv("TRENCHES_MODEL_NAME_US", "openai/gpt-oss-120b")
    client = TestClient(create_app(session_manager=build_manager()))

    response = client.get("/capabilities")
    assert response.status_code == 200
    us_binding = response.json()["model_bindings"]["us"]
    assert us_binding["configured"] is True
    assert us_binding["decision_mode"] == "provider_ready"
    assert us_binding["provider"] == "huggingface"
    # The huggingface provider is expected to name HF_TOKEN as its key variable.
    assert us_binding["api_key_env"] == "HF_TOKEN"
    assert "negotiate" in us_binding["action_tools"]


def test_server_ingests_news_and_exposes_reaction_log() -> None:
    """Ingesting one news signal yields a reaction that also appears in the log.

    The reaction echoes the signal, has latent events, updates Gulf beliefs,
    and reports outcomes for exactly the targeted agents. ``/reactions`` then
    returns that single reaction.
    """
    client = TestClient(create_app(session_manager=build_manager()))

    created = client.post("/sessions", json={"seed": 7})
    assert created.status_code == 200
    session_id = created.json()["session_id"]

    drone_signal = {
        "source": "wire-service",
        "headline": "Shipping risk rises in Hormuz after reported drone intercept.",
        "region": "gulf",
        "tags": ["shipping", "attack"],
        "severity": 0.76,
    }
    targeted_agents = ["us", "gulf", "oversight"]
    ingest_response = client.post(
        f"/sessions/{session_id}/news",
        json={"signals": [drone_signal], "agent_ids": targeted_agents},
    )
    assert ingest_response.status_code == 200
    payload = ingest_response.json()
    reaction = payload["reaction"]
    assert reaction is not None
    assert reaction["signals"][0]["source"] == "wire-service"
    assert reaction["latent_event_ids"]
    assert payload["session"]["belief_state"]["gulf"]["beliefs"]
    # The agents that report outcomes must match the agents targeted by the ingest.
    reporting_agents = {outcome["agent_id"] for outcome in reaction["actor_outcomes"]}
    assert reporting_agents == set(targeted_agents)

    log_response = client.get(f"/sessions/{session_id}/reactions")
    assert log_response.status_code == 200
    logged = log_response.json()
    assert len(logged) == 1
    assert logged[0]["event_id"] == reaction["event_id"]


def test_server_rejects_empty_news_ingest() -> None:
    """Posting a news batch with an empty ``signals`` list is rejected with HTTP 400."""
    app = create_app(session_manager=build_manager())
    client = TestClient(app)

    session_response = client.post("/sessions", json={"seed": 7})
    # If session creation breaks, fail here with a clear status mismatch
    # instead of an opaque KeyError on "session_id" below.
    assert session_response.status_code == 200
    session_id = session_response.json()["session_id"]

    ingest_response = client.post(
        f"/sessions/{session_id}/news",
        json={"signals": []},
    )
    assert ingest_response.status_code == 400
    assert "At least one external signal is required." in ingest_response.json()["detail"]


def test_server_exposes_provider_diagnostics() -> None:
    """A freshly created session exposes per-agent provider diagnostics.

    The ``us`` entry must be idle or fallback-only and report zero provider
    requests.
    """
    app = create_app(session_manager=build_manager())
    client = TestClient(app)

    session_response = client.post("/sessions", json={"seed": 7})
    # Check session creation explicitly so a failure there is reported clearly.
    assert session_response.status_code == 200
    session_id = session_response.json()["session_id"]

    diagnostics_response = client.get(f"/sessions/{session_id}/providers/diagnostics")
    assert diagnostics_response.status_code == 200
    diagnostics = diagnostics_response.json()
    # Use a default so a missing "us" entry fails with a readable message
    # rather than a bare StopIteration.
    us_diagnostics = next(
        (entry for entry in diagnostics["agents"] if entry["agent_id"] == "us"),
        None,
    )
    assert us_diagnostics is not None, "no provider diagnostics entry for agent 'us'"

    assert us_diagnostics["status"] in {"idle", "fallback_only"}
    assert us_diagnostics["request_count"] == 0