# File size: 6,653 Bytes
# 1794757
from __future__ import annotations
from fastapi.testclient import TestClient
from fastapi.middleware.cors import CORSMiddleware
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.server import DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX, create_app
from trenches_env.session_manager import SessionManager
from trenches_env.source_ingestion import SourceHarvester
def build_manager() -> SessionManager:
    """Build a SessionManager for the tests in this module.

    The wrapped FogOfWarDiplomacyEnv receives a SourceHarvester constructed
    with ``auto_start=False``.
    """
    harvester = SourceHarvester(auto_start=False)
    return SessionManager(env=FogOfWarDiplomacyEnv(source_harvester=harvester))
def _cors_kwargs(app) -> dict[str, object]:
    """Return the keyword arguments the app's CORS middleware was registered with.

    Scans ``app.user_middleware`` in order and returns the ``kwargs`` of the
    first entry whose ``cls`` is ``CORSMiddleware``.

    Raises:
        AssertionError: If no entry uses ``CORSMiddleware``, so the calling
            test fails with a readable message instead of a bare
            ``StopIteration``.
    """
    for entry in app.user_middleware:
        if entry.cls is CORSMiddleware:
            return entry.kwargs
    raise AssertionError("CORSMiddleware is not registered on the application")
def test_server_defaults_to_localhost_any_port_cors(monkeypatch) -> None:
    """With every CORS env var unset, the app falls back to the local-dev regex.

    Expects an empty explicit origin list, the default local-dev origin
    regex, and credentials allowed.
    """
    cors_env_vars = (
        "TRENCHES_CORS_ALLOW_ORIGINS",
        "TRENCHES_CORS_ALLOW_ORIGIN_REGEX",
        "TRENCHES_CORS_ALLOW_CREDENTIALS",
    )
    for env_var in cors_env_vars:
        monkeypatch.delenv(env_var, raising=False)

    settings = _cors_kwargs(create_app(session_manager=build_manager()))

    assert settings["allow_origins"] == []
    assert settings["allow_origin_regex"] == DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX
    assert settings["allow_credentials"] is True
def test_server_honors_explicit_cors_origin_list(monkeypatch) -> None:
    """A comma-separated origin list in the environment is split and used as-is.

    Expects the two origins with surrounding whitespace stripped, no origin
    regex, and credentials disabled as requested by the env var.
    """
    monkeypatch.setenv(
        "TRENCHES_CORS_ALLOW_ORIGINS",
        "https://dashboard.example.com, https://ops.example.com",
    )
    monkeypatch.delenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX", raising=False)
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "false")

    settings = _cors_kwargs(create_app(session_manager=build_manager()))

    expected_origins = ["https://dashboard.example.com", "https://ops.example.com"]
    assert settings["allow_origins"] == expected_origins
    assert settings["allow_origin_regex"] is None
    assert settings["allow_credentials"] is False
def test_server_disables_credentials_for_wildcard_cors(monkeypatch) -> None:
    """A wildcard origin turns credentials off even when the env var asks for them.

    Expects ``["*"]`` as the origin list, no origin regex, and credentials
    disabled although TRENCHES_CORS_ALLOW_CREDENTIALS is set to "true".
    """
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_ORIGINS", "*")
    monkeypatch.delenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX", raising=False)
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "true")

    settings = _cors_kwargs(create_app(session_manager=build_manager()))

    assert settings["allow_origins"] == ["*"]
    assert settings["allow_origin_regex"] is None
    assert settings["allow_credentials"] is False
def test_server_exposes_scenarios_and_benchmark_endpoints() -> None:
    """GET /scenarios lists ``shipping_crisis`` and POST /benchmarks/run can run it."""
    client = TestClient(create_app(session_manager=build_manager()))

    # The scenario catalogue must include the scenario benchmarked below.
    listing = client.get("/scenarios")
    assert listing.status_code == 200
    catalogue = listing.json()
    assert any(entry["id"] == "shipping_crisis" for entry in catalogue)

    # Run a short single-scenario benchmark and check the summary.
    run_request = {
        "scenario_ids": ["shipping_crisis"],
        "seed": 21,
        "steps_per_scenario": 2,
    }
    run_response = client.post("/benchmarks/run", json=run_request)
    assert run_response.status_code == 200
    summary = run_response.json()
    assert summary["scenario_count"] == 1
    assert summary["results"][0]["scenario_id"] == "shipping_crisis"
def test_capabilities_expose_model_provider_bindings(monkeypatch) -> None:
    """GET /capabilities reports the provider binding configured via env vars for ``us``."""
    monkeypatch.setenv("TRENCHES_MODEL_PROVIDER_US", "huggingface")
    monkeypatch.setenv("TRENCHES_MODEL_NAME_US", "openai/gpt-oss-120b")

    client = TestClient(create_app(session_manager=build_manager()))
    reply = client.get("/capabilities")
    assert reply.status_code == 200

    # Every remaining assertion targets the same binding, so look it up once.
    us_binding = reply.json()["model_bindings"]["us"]
    assert us_binding["configured"] is True
    assert us_binding["decision_mode"] == "provider_ready"
    assert us_binding["provider"] == "huggingface"
    assert us_binding["api_key_env"] == "HF_TOKEN"
    assert "negotiate" in us_binding["action_tools"]
def test_server_ingests_news_and_exposes_reaction_log() -> None:
    """Ingesting one news signal yields a reaction that then appears in the reaction log."""
    client = TestClient(create_app(session_manager=build_manager()))

    created = client.post("/sessions", json={"seed": 7})
    assert created.status_code == 200
    session_id = created.json()["session_id"]

    # Post a single external signal addressed to three agents.
    signal = {
        "source": "wire-service",
        "headline": "Shipping risk rises in Hormuz after reported drone intercept.",
        "region": "gulf",
        "tags": ["shipping", "attack"],
        "severity": 0.76,
    }
    ingested = client.post(
        f"/sessions/{session_id}/news",
        json={"signals": [signal], "agent_ids": ["us", "gulf", "oversight"]},
    )
    assert ingested.status_code == 200

    body = ingested.json()
    reaction = body["reaction"]
    assert reaction is not None
    assert reaction["signals"][0]["source"] == "wire-service"
    assert reaction["latent_event_ids"]
    assert body["session"]["belief_state"]["gulf"]["beliefs"]
    reacting_agents = {outcome["agent_id"] for outcome in reaction["actor_outcomes"]}
    assert reacting_agents == {"us", "gulf", "oversight"}

    # The reaction log holds exactly the reaction that was just produced.
    logged = client.get(f"/sessions/{session_id}/reactions")
    assert logged.status_code == 200
    entries = logged.json()
    assert len(entries) == 1
    assert entries[0]["event_id"] == reaction["event_id"]
def test_server_rejects_empty_news_ingest() -> None:
    """Posting an empty ``signals`` list to a session's news endpoint returns HTTP 400."""
    app = create_app(session_manager=build_manager())
    client = TestClient(app)

    session_response = client.post("/sessions", json={"seed": 7})
    # Check session creation explicitly so a failure here is reported as a
    # status mismatch rather than a KeyError on "session_id" below.
    assert session_response.status_code == 200
    session_id = session_response.json()["session_id"]

    ingest_response = client.post(
        f"/sessions/{session_id}/news",
        json={"signals": []},
    )
    assert ingest_response.status_code == 400
    assert "At least one external signal is required." in ingest_response.json()["detail"]
def test_server_exposes_provider_diagnostics() -> None:
    """A fresh session's provider diagnostics include an untouched entry for agent ``us``."""
    app = create_app(session_manager=build_manager())
    client = TestClient(app)

    session_response = client.post("/sessions", json={"seed": 7})
    # Check session creation explicitly so a failure here is reported as a
    # status mismatch rather than a KeyError on "session_id" below.
    assert session_response.status_code == 200
    session_id = session_response.json()["session_id"]

    diagnostics_response = client.get(f"/sessions/{session_id}/providers/diagnostics")
    assert diagnostics_response.status_code == 200
    diagnostics = diagnostics_response.json()

    # A None default turns a missing "us" entry into a readable assertion
    # failure instead of a bare StopIteration.
    us_diagnostics = next(
        (entry for entry in diagnostics["agents"] if entry["agent_id"] == "us"),
        None,
    )
    assert us_diagnostics is not None, "no diagnostics entry reported for agent 'us'"
    assert us_diagnostics["status"] in {"idle", "fallback_only"}
    assert us_diagnostics["request_count"] == 0
|