trenches / backend /tests /test_session_manager.py
Codex
sync main snapshot for HF Space
1794757
import time
from datetime import datetime, timedelta, timezone
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.models import (
AgentAction,
BenchmarkRunRequest,
ExternalSignal,
IngestNewsRequest,
LiveControlRequest,
StepSessionRequest,
)
from trenches_env.session_manager import SessionManager
from trenches_env.source_bundles import AGENT_LIVE_SOURCE_BUNDLES
from trenches_env.source_ingestion import SourceHarvester
class ShippingFeedFetcher:
def fetch(self, _: str) -> tuple[str, str]:
return (
"""
<rss>
<channel>
<title>Maritime Watch</title>
<item>
<title>Shipping risk rises in Hormuz after drone intercept near tanker lanes</title>
</item>
</channel>
</rss>
""",
"application/rss+xml",
)
def build_live_manager() -> SessionManager:
harvester = SourceHarvester(fetcher=ShippingFeedFetcher(), auto_start=False)
env = FogOfWarDiplomacyEnv(source_harvester=harvester)
return SessionManager(env=env)
def test_session_lifecycle() -> None:
manager = SessionManager()
session = manager.create_session(seed=7)
assert session.world.turn == 0
assert session.session_id
assert session.episode.max_turns == 1000
live_session = manager.set_live_mode(
session.session_id,
LiveControlRequest(enabled=True, auto_step=False, poll_interval_ms=15_000),
)
assert live_session.live.enabled is True
assert live_session.live.source_queue_sizes["us"] == len(AGENT_LIVE_SOURCE_BUNDLES["us"])
response = manager.step_session(
session.session_id,
StepSessionRequest(
actions={
"us": AgentAction(
actor="us",
type="negotiate",
target="gulf",
summary="Offer deconfliction and shipping guarantees.",
),
"oversight": AgentAction(
actor="oversight",
type="oversight_review",
summary="Monitor for escalation drift.",
),
}
),
)
assert response.session.world.turn == 1
assert "gulf" in response.session.world.coalition_graph["us"]
assert "us" in response.session.observations
assert response.session.observations["us"].training_source_bundle
assert response.session.observations["us"].live_source_bundle
assert response.session.recent_traces
assert response.session.recent_traces[-1].turn == 1
assert response.session.model_bindings["us"].action_tools
assert response.session.model_bindings["us"].decision_mode == "heuristic_fallback"
assert response.session.action_log
assert {entry.actor for entry in response.session.action_log[-2:]} == {"us", "oversight"}
def test_stage_1_disables_live_mode() -> None:
manager = SessionManager()
session = manager.create_session(seed=7, training_stage="stage_1_dense")
assert session.episode.training_stage == "stage_1_dense"
assert session.episode.fog_of_war is False
assert session.observations["us"].perceived_tension == session.world.tension_level
try:
manager.set_live_mode(
session.session_id,
LiveControlRequest(enabled=True, auto_step=False, poll_interval_ms=15_000),
)
except ValueError:
pass
else:
raise AssertionError("stage_1_dense sessions should reject live mode")
def test_session_manager_creates_named_scenarios() -> None:
manager = SessionManager()
session = manager.create_session(seed=7, scenario_id="shipping_crisis")
assert session.episode.scenario_id == "shipping_crisis"
assert session.episode.scenario_name == "Shipping Crisis"
assert session.world.actor_state["gulf"]["shipping_continuity"] < 78.0
def test_source_monitor_report_is_available_for_sessions() -> None:
manager = SessionManager()
session = manager.create_session(seed=7)
report = manager.source_monitor(session.session_id)
assert report.session_id == session.session_id
assert len(report.agents) == len(AGENT_LIVE_SOURCE_BUNDLES)
assert report.summary.active_source_count > 0
def test_live_get_session_auto_steps_once_for_new_source_packets() -> None:
manager = build_live_manager()
session = manager.create_session(seed=7)
manager.set_live_mode(
session.session_id,
LiveControlRequest(enabled=True, auto_step=True, poll_interval_ms=1_000),
)
first_live_tick = manager.get_session(session.session_id)
assert first_live_tick.world.turn == 1
assert first_live_tick.live.last_auto_step_at is not None
assert first_live_tick.live.reacted_packet_fetched_at
assert first_live_tick.recent_traces[-1].actions["gulf"].type == "defend"
assert first_live_tick.recent_traces[-1].actions["oversight"].type == "oversight_review"
for observation in first_live_tick.observations.values():
for packet in observation.source_packets:
if packet.fetched_at is not None:
manager._sessions[session.session_id].live.reacted_packet_fetched_at[packet.source_id] = packet.fetched_at
manager._sessions[session.session_id].live.last_auto_step_at = datetime.now(timezone.utc) - timedelta(seconds=2)
second_live_tick = manager.get_session(session.session_id)
assert second_live_tick.world.turn == 1
def test_oversight_replaces_escalatory_actions_with_valid_overrides() -> None:
manager = SessionManager()
session = manager.create_session(seed=7)
response = manager.step_session(
session.session_id,
StepSessionRequest(
actions={
"us": AgentAction(actor="us", type="strike", target="iran", summary="Strike escalation lane."),
"israel": AgentAction(actor="israel", type="strike", target="hezbollah", summary="Strike escalation lane."),
"iran": AgentAction(actor="iran", type="mobilize", summary="Mobilize escalation lane."),
"hezbollah": AgentAction(actor="hezbollah", type="deceive", summary="Deception escalation lane."),
}
),
)
assert response.oversight.triggered is True
assert response.oversight.action_override
for agent_id, override_action in response.oversight.action_override.items():
assert override_action.type in {"hold", "negotiate", "defend", "intel_query"}
assert response.session.recent_traces[-1].actions[agent_id].type == override_action.type
assert response.session.world.last_actions
def test_background_runner_advances_live_sessions_without_dashboard_polling() -> None:
manager = build_live_manager()
session = manager.create_session(seed=7)
manager.set_live_mode(
session.session_id,
LiveControlRequest(enabled=True, auto_step=True, poll_interval_ms=1_000),
)
manager.start_background_runner(tick_interval_seconds=0.05)
try:
deadline = time.time() + 1.5
while time.time() < deadline:
current = manager._sessions[session.session_id]
if current.world.turn >= 1:
break
time.sleep(0.05)
current = manager._sessions[session.session_id]
finally:
manager.stop_background_runner()
assert current.world.turn >= 1
assert current.live.last_auto_step_at is not None
assert current.reaction_log
def test_ingest_news_generates_structured_reaction_log() -> None:
manager = SessionManager()
session = manager.create_session(seed=7)
response = manager.ingest_news(
session.session_id,
IngestNewsRequest(
signals=[
ExternalSignal(
source="wire-service",
headline="Shipping risk rises in Hormuz after reported drone intercept.",
region="gulf",
tags=["shipping", "attack"],
severity=0.76,
)
],
agent_ids=["us", "gulf", "oversight"],
),
)
assert response.session.world.turn == 1
assert response.reaction is not None
assert response.reaction.turn == 1
assert response.reaction.signals[0].source == "wire-service"
assert response.reaction.latent_event_ids
assert {outcome.agent_id for outcome in response.reaction.actor_outcomes} == {"us", "gulf", "oversight"}
assert all(outcome.action.metadata["mode"] in {"heuristic_fallback", "provider_inference"} for outcome in response.reaction.actor_outcomes)
assert response.session.reaction_log[-1].event_id == response.reaction.event_id
assert manager.reaction_log(session.session_id)[-1].event_id == response.reaction.event_id
assert response.session.belief_state["gulf"].beliefs
assert response.session.observations["gulf"].belief_brief
def test_provider_diagnostics_are_available_per_session() -> None:
manager = SessionManager()
session = manager.create_session(seed=7)
diagnostics = manager.provider_diagnostics(session.session_id)
us_diagnostics = next(entry for entry in diagnostics.agents if entry.agent_id == "us")
assert us_diagnostics.agent_id == "us"
assert us_diagnostics.status in {"idle", "fallback_only"}
assert us_diagnostics.request_count == 0
def test_session_manager_lists_scenarios_and_runs_benchmarks() -> None:
manager = SessionManager()
scenarios = manager.list_scenarios()
assert any(scenario.id == "shipping_crisis" for scenario in scenarios)
result = manager.run_benchmark(
BenchmarkRunRequest(
scenario_ids=["shipping_crisis"],
seed=9,
steps_per_scenario=3,
)
)
assert result.scenario_count == 1
assert result.results[0].scenario_id == "shipping_crisis"