| from __future__ import annotations | |
| from fastapi.testclient import TestClient | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from trenches_env.env import FogOfWarDiplomacyEnv | |
| from trenches_env.server import DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX, create_app | |
| from trenches_env.session_manager import SessionManager | |
| from trenches_env.source_ingestion import SourceHarvester | |
def build_manager() -> SessionManager:
    """Build a SessionManager over an env whose harvester is created with auto_start=False."""
    harvester = SourceHarvester(auto_start=False)
    return SessionManager(env=FogOfWarDiplomacyEnv(source_harvester=harvester))
def _cors_kwargs(app) -> dict[str, object]:
    """Return the keyword arguments the app's CORS middleware was registered with.

    Scans ``app.user_middleware`` in order and returns the ``kwargs`` of the
    first entry whose ``cls`` is ``CORSMiddleware``.

    Raises:
        AssertionError: if no entry in ``app.user_middleware`` uses
            ``CORSMiddleware``. An explicit failure replaces the opaque
            ``StopIteration`` that a bare ``next()`` would produce.
    """
    for entry in app.user_middleware:
        if entry.cls is CORSMiddleware:
            return entry.kwargs
    # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
    raise AssertionError("CORSMiddleware is not registered on the app")
def test_server_defaults_to_localhost_any_port_cors(monkeypatch) -> None:
    """With the CORS env vars unset, expect the default origin regex and credentials enabled."""
    # Clear every CORS override so create_app falls back to its defaults.
    for env_name in (
        "TRENCHES_CORS_ALLOW_ORIGINS",
        "TRENCHES_CORS_ALLOW_ORIGIN_REGEX",
        "TRENCHES_CORS_ALLOW_CREDENTIALS",
    ):
        monkeypatch.delenv(env_name, raising=False)

    cors_settings = _cors_kwargs(create_app(session_manager=build_manager()))

    assert cors_settings["allow_origins"] == []
    assert cors_settings["allow_origin_regex"] == DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX
    assert cors_settings["allow_credentials"] is True
def test_server_honors_explicit_cors_origin_list(monkeypatch) -> None:
    """A comma-separated origin list in the env is split and stripped; no regex; credentials off."""
    monkeypatch.setenv(
        "TRENCHES_CORS_ALLOW_ORIGINS",
        "https://dashboard.example.com, https://ops.example.com",
    )
    monkeypatch.delenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX", raising=False)
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "false")

    cors_settings = _cors_kwargs(create_app(session_manager=build_manager()))

    expected_origins = ["https://dashboard.example.com", "https://ops.example.com"]
    assert cors_settings["allow_origins"] == expected_origins
    assert cors_settings["allow_origin_regex"] is None
    assert cors_settings["allow_credentials"] is False
def test_server_disables_credentials_for_wildcard_cors(monkeypatch) -> None:
    """A wildcard origin must yield allow_credentials=False even when the env asks for true."""
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_ORIGINS", "*")
    monkeypatch.delenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX", raising=False)
    monkeypatch.setenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "true")

    cors_settings = _cors_kwargs(create_app(session_manager=build_manager()))

    assert cors_settings["allow_origins"] == ["*"]
    assert cors_settings["allow_origin_regex"] is None
    assert cors_settings["allow_credentials"] is False
def test_server_exposes_scenarios_and_benchmark_endpoints() -> None:
    """GET /scenarios lists shipping_crisis and POST /benchmarks/run returns one result for it."""
    http = TestClient(create_app(session_manager=build_manager()))

    # The scenario listing must include the shipping_crisis scenario.
    listing = http.get("/scenarios")
    assert listing.status_code == 200
    scenario_ids = [scenario["id"] for scenario in listing.json()]
    assert "shipping_crisis" in scenario_ids

    # Run a short benchmark restricted to that single scenario.
    benchmark_request = {
        "scenario_ids": ["shipping_crisis"],
        "seed": 21,
        "steps_per_scenario": 2,
    }
    run = http.post("/benchmarks/run", json=benchmark_request)
    assert run.status_code == 200
    report = run.json()
    assert report["scenario_count"] == 1
    assert report["results"][0]["scenario_id"] == "shipping_crisis"
def test_capabilities_expose_model_provider_bindings(monkeypatch) -> None:
    """GET /capabilities reports the "us" model binding configured through env vars."""
    monkeypatch.setenv("TRENCHES_MODEL_PROVIDER_US", "huggingface")
    monkeypatch.setenv("TRENCHES_MODEL_NAME_US", "openai/gpt-oss-120b")

    http = TestClient(create_app(session_manager=build_manager()))
    reply = http.get("/capabilities")
    assert reply.status_code == 200

    # All remaining checks target the binding reported for the "us" agent.
    us_binding = reply.json()["model_bindings"]["us"]
    assert us_binding["configured"] is True
    assert us_binding["decision_mode"] == "provider_ready"
    assert us_binding["provider"] == "huggingface"
    assert us_binding["api_key_env"] == "HF_TOKEN"
    assert "negotiate" in us_binding["action_tools"]
def test_server_ingests_news_and_exposes_reaction_log() -> None:
    """Posting one news signal yields a reaction that then appears in the session's reaction log."""
    http = TestClient(create_app(session_manager=build_manager()))

    created = http.post("/sessions", json={"seed": 7})
    assert created.status_code == 200
    session_id = created.json()["session_id"]

    # Submit one external signal addressed to three agents.
    signal = {
        "source": "wire-service",
        "headline": "Shipping risk rises in Hormuz after reported drone intercept.",
        "region": "gulf",
        "tags": ["shipping", "attack"],
        "severity": 0.76,
    }
    ingest = http.post(
        f"/sessions/{session_id}/news",
        json={"signals": [signal], "agent_ids": ["us", "gulf", "oversight"]},
    )
    assert ingest.status_code == 200

    body = ingest.json()
    reaction = body["reaction"]
    assert reaction is not None
    assert reaction["signals"][0]["source"] == "wire-service"
    assert reaction["latent_event_ids"]
    assert body["session"]["belief_state"]["gulf"]["beliefs"]
    responding_agents = {outcome["agent_id"] for outcome in reaction["actor_outcomes"]}
    assert responding_agents == {"us", "gulf", "oversight"}

    # The reaction log must contain exactly the reaction just produced.
    log = http.get(f"/sessions/{session_id}/reactions")
    assert log.status_code == 200
    logged = log.json()
    assert len(logged) == 1
    assert logged[0]["event_id"] == reaction["event_id"]
def test_server_rejects_empty_news_ingest() -> None:
    """Posting an empty ``signals`` list to the news endpoint must return HTTP 400."""
    app = create_app(session_manager=build_manager())
    client = TestClient(app)

    session_response = client.post("/sessions", json={"seed": 7})
    # Fail here with a clear status mismatch rather than a KeyError on
    # "session_id" if session creation itself is broken.
    assert session_response.status_code == 200
    session_id = session_response.json()["session_id"]

    ingest_response = client.post(
        f"/sessions/{session_id}/news",
        json={"signals": []},
    )

    assert ingest_response.status_code == 400
    assert "At least one external signal is required." in ingest_response.json()["detail"]
def test_server_exposes_provider_diagnostics() -> None:
    """The diagnostics endpoint reports a "us" entry with no requests on a fresh session."""
    app = create_app(session_manager=build_manager())
    client = TestClient(app)

    session_response = client.post("/sessions", json={"seed": 7})
    # Fail here with a clear status mismatch rather than a KeyError on
    # "session_id" if session creation itself is broken.
    assert session_response.status_code == 200
    session_id = session_response.json()["session_id"]

    diagnostics_response = client.get(f"/sessions/{session_id}/providers/diagnostics")
    assert diagnostics_response.status_code == 200
    diagnostics = diagnostics_response.json()

    # Use a default so a missing entry fails as an assertion, not StopIteration.
    us_diagnostics = next(
        (entry for entry in diagnostics["agents"] if entry["agent_id"] == "us"),
        None,
    )
    assert us_diagnostics is not None, 'no diagnostics entry for agent "us"'
    assert us_diagnostics["status"] in {"idle", "fallback_only"}
    assert us_diagnostics["request_count"] == 0