from __future__ import annotations

import logging
import threading
from datetime import datetime, timedelta, timezone

from trenches_env.benchmark_runner import ScenarioBenchmarkRunner
from trenches_env.rl import DEFAULT_TRAINING_STAGE
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.models import (
    BenchmarkRunRequest,
    BenchmarkRunResponse,
    IngestNewsRequest,
    IngestNewsResponse,
    LiveControlRequest,
    ProviderDiagnosticsResponse,
    ReactionLogEntry,
    ScenarioSummary,
    SessionState,
    SourceMonitorReport,
    StepSessionRequest,
    StepSessionResponse,
)
from trenches_env.source_ingestion import SourceHarvester


class SessionManager:
    """Keep track of sessions created through a ``FogOfWarDiplomacyEnv``.

    The manager stores the latest ``SessionState`` for each session id and
    forwards every operation to the wrapped environment.  Methods that read
    or replace a stored session do so while holding ``self._lock``.  An
    optional daemon thread periodically auto-steps sessions whose live mode
    and auto-step flags are both enabled.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, env: FogOfWarDiplomacyEnv | None = None) -> None:
        """Wrap *env*, or a freshly constructed environment when it is ``None``."""
        # Compare against None explicitly so a caller-supplied environment is
        # never discarded just because it happens to evaluate as falsy.
        self.env = env if env is not None else FogOfWarDiplomacyEnv()
        self._sessions: dict[str, SessionState] = {}
        # Re-entrant so a thread already holding the lock can acquire it again.
        self._lock = threading.RLock()
        self._background_tick_seconds = 1.0
        self._background_stop = threading.Event()
        self._background_thread: threading.Thread | None = None

    def start_background_runner(self, tick_interval_seconds: float | None = None) -> None:
        """Start the daemon thread that ticks live sessions.

        Args:
            tick_interval_seconds: Optional new delay between ticks.  Values
                below 0.05 seconds are raised to 0.05.  The interval is
                updated even when the runner is already alive.

        Does nothing else if the background thread is already running.
        """
        with self._lock:
            if tick_interval_seconds is not None:
                self._background_tick_seconds = max(0.05, tick_interval_seconds)
            if self._background_thread is not None and self._background_thread.is_alive():
                return
            self._background_stop.clear()
            self._background_thread = threading.Thread(
                target=self._run_background_loop,
                name="trenches-session-manager",
                daemon=True,
            )
            self._background_thread.start()

    def stop_background_runner(self) -> None:
        """Signal the background thread to stop and wait briefly for it to exit.

        The join is bounded by ``max(1.0, 2 * tick interval)`` seconds.
        """
        self._background_stop.set()
        thread = self._background_thread
        if thread is not None and thread.is_alive():
            # Deliberately joined without holding ``self._lock``: the worker
            # acquires that lock in ``_tick_live_sessions`` and could not
            # finish its current tick if this thread were holding it.
            thread.join(timeout=max(1.0, self._background_tick_seconds * 2.0))
        # NOTE(review): the reference is dropped even if the join timed out;
        # a start call issued before the old thread observes the stop flag
        # could then leave two loops running -- confirm whether that matters.
        self._background_thread = None

    def shutdown(self) -> None:
        """Stop the background runner, then shut the wrapped environment down."""
        self.stop_background_runner()
        self.env.shutdown()

    def create_session(
        self,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: str = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> SessionState:
        """Create a session through the environment and register it.

        All arguments are forwarded unchanged to ``env.create_session``.

        Returns:
            The new session, stored under its ``session_id``.
        """
        with self._lock:
            session = self.env.create_session(
                seed=seed,
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=max_turns,
                scenario_id=scenario_id,
                replay_id=replay_id,
                replay_start_index=replay_start_index,
            )
            self._sessions[session.session_id] = session
            return session

    def reset_session(
        self,
        session_id: str,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: str = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> SessionState:
        """Reset an existing session through the environment.

        All arguments are forwarded unchanged to ``env.reset_session``.

        Returns:
            The session returned by the environment, stored under *session_id*.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            self._require_session(session_id)
            session = self.env.reset_session(
                session_id=session_id,
                seed=seed,
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=max_turns,
                scenario_id=scenario_id,
                replay_id=replay_id,
                replay_start_index=replay_start_index,
            )
            self._sessions[session_id] = session
            return session

    def get_session(self, session_id: str) -> SessionState:
        """Return the session after giving the environment a chance to update it.

        Sessions with live mode and auto-step both enabled go through
        ``env.maybe_auto_step_live_session``; all others go through
        ``env.refresh_session_sources``.  The result replaces the stored state.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            session = self._require_session(session_id)
            if session.live.enabled and session.live.auto_step:
                refreshed = self.env.maybe_auto_step_live_session(session)
            else:
                refreshed = self.env.refresh_session_sources(session)
            self._sessions[session_id] = refreshed
            return refreshed

    def set_live_mode(self, session_id: str, request: LiveControlRequest) -> SessionState:
        """Apply *request* via ``env.configure_live_session`` and store the result.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            updated = self.env.configure_live_session(current, request)
            self._sessions[session_id] = updated
            return updated

    def step_session(self, session_id: str, request: StepSessionRequest) -> StepSessionResponse:
        """Advance the session via ``env.step_session`` and store ``result.session``.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            result = self.env.step_session(current, request)
            self._sessions[session_id] = result.session
            return result

    def ingest_news(self, session_id: str, request: IngestNewsRequest) -> IngestNewsResponse:
        """Feed external signals into a session and step it once.

        The session's sources are refreshed, actions are resolved from the
        signals (restricted to ``request.agent_ids`` when that is non-empty),
        and the session is stepped with those actions and signals.

        Returns:
            A response carrying the stepped session, its oversight data, the
            last entry of the session's reaction log (``None`` when the log is
            empty) and the ``done`` flag.

        Raises:
            ValueError: If ``request.signals`` is empty.
            KeyError: If *session_id* is not registered.
        """
        # Pure request validation; it does not need the lock and still runs
        # before the session lookup, exactly as before.
        if not request.signals:
            raise ValueError("At least one external signal is required.")
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            actions = self.env.resolve_policy_actions(
                refreshed,
                request.signals,
                agent_ids=request.agent_ids or None,
            )
            result = self.env.step_session(
                refreshed,
                StepSessionRequest(actions=actions, external_signals=request.signals),
            )
            self._sessions[session_id] = result.session
            reaction_log = result.session.reaction_log
            reaction: ReactionLogEntry | None = reaction_log[-1] if reaction_log else None
            return IngestNewsResponse(
                session=result.session,
                oversight=result.oversight,
                reaction=reaction,
                done=result.done,
            )

    def refresh_session_sources(self, session_id: str, force: bool = False) -> SessionState:
        """Refresh the session's sources via the environment and store the result.

        Args:
            session_id: Id of a registered session.
            force: Forwarded unchanged to ``env.refresh_session_sources``.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current, force=force)
            self._sessions[session_id] = refreshed
            return refreshed

    def source_monitor(self, session_id: str) -> SourceMonitorReport:
        """Refresh the session's sources, store it, and return ``env.source_monitor``.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            self._sessions[session_id] = refreshed
            return self.env.source_monitor(refreshed)

    def reaction_log(self, session_id: str) -> list[ReactionLogEntry]:
        """Return deep copies of the stored session's reaction-log entries.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            return [entry.model_copy(deep=True) for entry in current.reaction_log]

    def provider_diagnostics(self, session_id: str) -> ProviderDiagnosticsResponse:
        """Refresh the session's sources, store it, and return ``env.provider_diagnostics``.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            self._sessions[session_id] = refreshed
            return self.env.provider_diagnostics(refreshed)

    def list_scenarios(self) -> list[ScenarioSummary]:
        """Return a ``ScenarioSummary`` for every scenario the environment lists."""
        return [
            ScenarioSummary(
                id=scenario.id,
                name=scenario.name,
                description=scenario.description,
                tags=list(scenario.tags),
                benchmark_turns=scenario.benchmark_turns,
                benchmark_enabled=scenario.benchmark_enabled,
            )
            for scenario in self.env.list_scenarios()
        ]

    def run_benchmark(self, request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        """Run *request* with a ``ScenarioBenchmarkRunner``.

        The runner is handed a factory that builds new environments whose
        source harvester is created with ``auto_start=False``; it does not use
        ``self.env``.
        """
        runner = ScenarioBenchmarkRunner(
            env_factory=lambda: FogOfWarDiplomacyEnv(source_harvester=SourceHarvester(auto_start=False))
        )
        return runner.run(request)

    def _run_background_loop(self) -> None:
        """Tick live sessions until the stop event is set (background thread target)."""
        while not self._background_stop.is_set():
            try:
                self._tick_live_sessions()
            except Exception:
                # Best effort: keep the loop alive, but record the failure
                # instead of discarding it silently.
                self._logger.exception("Background live-session tick failed")
            # Event.wait doubles as an interruptible sleep between ticks.
            self._background_stop.wait(self._background_tick_seconds)

    def _tick_live_sessions(self) -> None:
        """Auto-step every registered session that is due for a live tick.

        A failure while handling one session is logged and does not prevent
        the remaining sessions from being processed.
        """
        now = datetime.now(timezone.utc)
        with self._lock:
            # Iterate over a snapshot because entries are replaced in the loop.
            for session_id, session in list(self._sessions.items()):
                try:
                    if not self._session_needs_live_tick(session, now):
                        continue
                    self._sessions[session_id] = self.env.maybe_auto_step_live_session(session)
                except Exception:
                    self._logger.exception("Auto-step failed for live session %s", session_id)

    @staticmethod
    def _session_needs_live_tick(session: SessionState, now: datetime) -> bool:
        """Tell whether *session* should be auto-stepped at time *now*.

        Returns ``False`` unless live mode and auto-step are both enabled,
        ``True`` if the session has never been auto-stepped, and otherwise
        whether at least the poll interval (never less than 1000 ms) has
        passed since ``last_auto_step_at``.
        """
        live = session.live
        if not live.enabled or not live.auto_step:
            return False
        if live.last_auto_step_at is None:
            return True
        interval = timedelta(milliseconds=max(live.poll_interval_ms, 1_000))
        return now - live.last_auto_step_at >= interval

    def _require_session(self, session_id: str) -> SessionState:
        """Return the stored session for *session_id*.

        Raises:
            KeyError: If no session is registered under *session_id*.
        """
        session = self._sessions.get(session_id)
        if session is None:
            raise KeyError(session_id)
        return session