Spaces:
Paused
Paused
| from __future__ import annotations | |
| import threading | |
| from datetime import datetime, timedelta, timezone | |
| from trenches_env.benchmark_runner import ScenarioBenchmarkRunner | |
| from trenches_env.rl import DEFAULT_TRAINING_STAGE | |
| from trenches_env.env import FogOfWarDiplomacyEnv | |
| from trenches_env.models import ( | |
| BenchmarkRunRequest, | |
| BenchmarkRunResponse, | |
| IngestNewsRequest, | |
| IngestNewsResponse, | |
| LiveControlRequest, | |
| ProviderDiagnosticsResponse, | |
| ReactionLogEntry, | |
| ScenarioSummary, | |
| SessionState, | |
| SourceMonitorReport, | |
| StepSessionRequest, | |
| StepSessionResponse, | |
| ) | |
| from trenches_env.source_ingestion import SourceHarvester | |
class SessionManager:
    """Own a ``FogOfWarDiplomacyEnv`` and the sessions created through it.

    Sessions are kept in an in-memory dict keyed by ``session_id``. Methods
    that read or replace entries in that dict do so while holding a
    re-entrant lock, because an optional daemon thread (see
    ``start_background_runner``) periodically advances live sessions.
    """

    def __init__(self, env: FogOfWarDiplomacyEnv | None = None) -> None:
        """Create a manager around *env*, or a default environment if omitted."""
        self.env = env or FogOfWarDiplomacyEnv()
        self._sessions: dict[str, SessionState] = {}
        self._lock = threading.RLock()
        self._background_tick_seconds = 1.0
        self._background_stop = threading.Event()
        self._background_thread: threading.Thread | None = None

    def start_background_runner(self, tick_interval_seconds: float | None = None) -> None:
        """Start the background tick thread if it is not already running.

        Args:
            tick_interval_seconds: Optional new tick period. Values below
                0.05 seconds are raised to 0.05. The period is updated even
                when a runner thread is already alive.
        """
        with self._lock:
            if tick_interval_seconds is not None:
                self._background_tick_seconds = max(0.05, tick_interval_seconds)
            if self._background_thread is not None and self._background_thread.is_alive():
                return
            self._background_stop.clear()
            self._background_thread = threading.Thread(
                target=self._run_background_loop,
                name="trenches-session-manager",
                daemon=True,
            )
            self._background_thread.start()

    def stop_background_runner(self) -> None:
        """Signal the background loop to stop and wait briefly for it to exit.

        The join waits at most ``max(1.0, 2 * tick interval)`` seconds.
        """
        self._background_stop.set()
        thread = self._background_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=max(1.0, self._background_tick_seconds * 2.0))
        # NOTE(review): the reference is dropped even if the join timed out
        # and the thread is still alive -- confirm that is acceptable.
        self._background_thread = None

    def shutdown(self) -> None:
        """Stop the background runner, then shut down the environment."""
        self.stop_background_runner()
        self.env.shutdown()

    def create_session(
        self,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: str = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> SessionState:
        """Create a session via the environment and register it.

        All arguments are forwarded unchanged to ``env.create_session``.

        Returns:
            The newly created session, stored under its ``session_id``.
        """
        with self._lock:
            session = self.env.create_session(
                seed=seed,
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=max_turns,
                scenario_id=scenario_id,
                replay_id=replay_id,
                replay_start_index=replay_start_index,
            )
            self._sessions[session.session_id] = session
            return session

    def reset_session(
        self,
        session_id: str,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: str = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> SessionState:
        """Reset an existing session via the environment and store the result.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            # Validate the id before asking the environment to do any work.
            self._require_session(session_id)
            session = self.env.reset_session(
                session_id=session_id,
                seed=seed,
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=max_turns,
                scenario_id=scenario_id,
                replay_id=replay_id,
                replay_start_index=replay_start_index,
            )
            self._sessions[session_id] = session
            return session

    def get_session(self, session_id: str) -> SessionState:
        """Return the session after bringing it up to date.

        Sessions with live mode enabled and ``auto_step`` set are passed to
        ``env.maybe_auto_step_live_session``; all others only have their
        sources refreshed. The stored session is replaced by the result.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            session = self._require_session(session_id)
            if session.live.enabled and session.live.auto_step:
                refreshed = self.env.maybe_auto_step_live_session(session)
            else:
                refreshed = self.env.refresh_session_sources(session)
            self._sessions[session_id] = refreshed
            return refreshed

    def set_live_mode(self, session_id: str, request: LiveControlRequest) -> SessionState:
        """Apply *request* through ``env.configure_live_session`` and store the result.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            updated = self.env.configure_live_session(current, request)
            self._sessions[session_id] = updated
            return updated

    def step_session(self, session_id: str, request: StepSessionRequest) -> StepSessionResponse:
        """Step the session through the environment and store the new state.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            result = self.env.step_session(current, request)
            self._sessions[session_id] = result.session
            return result

    def ingest_news(self, session_id: str, request: IngestNewsRequest) -> IngestNewsResponse:
        """Feed external signals into a session and step it once.

        Sources are refreshed first, then actions are resolved from the
        signals with ``env.resolve_policy_actions`` and the session is
        stepped with those actions and signals.

        Returns:
            A response carrying the stepped session, its oversight result,
            the last reaction-log entry (``None`` if the log is empty) and
            the ``done`` flag.

        Raises:
            ValueError: If ``request.signals`` is empty.
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            if not request.signals:
                raise ValueError("At least one external signal is required.")
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            actions = self.env.resolve_policy_actions(
                refreshed,
                request.signals,
                # An empty agent list is passed on as "no restriction" (None).
                agent_ids=request.agent_ids or None,
            )
            result = self.env.step_session(
                refreshed,
                StepSessionRequest(actions=actions, external_signals=request.signals),
            )
            self._sessions[session_id] = result.session
            reaction: ReactionLogEntry | None = (
                result.session.reaction_log[-1] if result.session.reaction_log else None
            )
            return IngestNewsResponse(
                session=result.session,
                oversight=result.oversight,
                reaction=reaction,
                done=result.done,
            )

    def refresh_session_sources(self, session_id: str, force: bool = False) -> SessionState:
        """Refresh the session's sources, forwarding *force* to the environment.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current, force=force)
            self._sessions[session_id] = refreshed
            return refreshed

    def source_monitor(self, session_id: str) -> SourceMonitorReport:
        """Refresh the session's sources and return the environment's monitor report.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            self._sessions[session_id] = refreshed
            return self.env.source_monitor(refreshed)

    def reaction_log(self, session_id: str) -> list[ReactionLogEntry]:
        """Return deep copies of the session's reaction-log entries.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            return [entry.model_copy(deep=True) for entry in current.reaction_log]

    def provider_diagnostics(self, session_id: str) -> ProviderDiagnosticsResponse:
        """Refresh the session's sources and return provider diagnostics for it.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            self._sessions[session_id] = refreshed
            return self.env.provider_diagnostics(refreshed)

    def list_scenarios(self) -> list[ScenarioSummary]:
        """Return a ``ScenarioSummary`` for every scenario the environment lists."""
        return [
            ScenarioSummary(
                id=scenario.id,
                name=scenario.name,
                description=scenario.description,
                tags=list(scenario.tags),
                benchmark_turns=scenario.benchmark_turns,
                benchmark_enabled=scenario.benchmark_enabled,
            )
            for scenario in self.env.list_scenarios()
        ]

    def run_benchmark(self, request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        """Run *request* with a ``ScenarioBenchmarkRunner``.

        The runner is given a factory that builds a fresh environment whose
        ``SourceHarvester`` is constructed with ``auto_start=False``; the
        manager's own ``self.env`` is not used.
        """
        runner = ScenarioBenchmarkRunner(
            env_factory=lambda: FogOfWarDiplomacyEnv(source_harvester=SourceHarvester(auto_start=False))
        )
        return runner.run(request)

    def _run_background_loop(self) -> None:
        """Tick live sessions until the stop event is set.

        A failing tick is logged and the loop carries on, so one bad tick
        does not end background processing.
        """
        # Function-scope import keeps this block self-contained.
        import logging

        logger = logging.getLogger(__name__)
        while not self._background_stop.is_set():
            try:
                self._tick_live_sessions()
            except Exception:
                # Best-effort loop: record the failure instead of hiding it.
                logger.exception("Live session tick failed")
            # Event.wait doubles as an interruptible sleep between ticks.
            self._background_stop.wait(self._background_tick_seconds)

    def _tick_live_sessions(self) -> None:
        """Auto-step every registered session that is due for a live tick."""
        now = datetime.now(timezone.utc)
        with self._lock:
            # Iterate over a snapshot because entries are replaced in the loop.
            for session_id, session in list(self._sessions.items()):
                if not self._session_needs_live_tick(session, now):
                    continue
                self._sessions[session_id] = self.env.maybe_auto_step_live_session(session)

    @staticmethod
    def _session_needs_live_tick(session: SessionState, now: datetime) -> bool:
        """Return whether *session* is due for an automatic live step at *now*.

        A session qualifies only when live mode and ``auto_step`` are both
        on. It is due immediately if it has never auto-stepped; otherwise
        once its poll interval (floored at 1000 ms) has elapsed.
        """
        if not session.live.enabled or not session.live.auto_step:
            return False
        if session.live.last_auto_step_at is None:
            return True
        interval = timedelta(milliseconds=max(session.live.poll_interval_ms, 1_000))
        return now - session.live.last_auto_step_at >= interval

    def _require_session(self, session_id: str) -> SessionState:
        """Return the stored session for *session_id*.

        Raises:
            KeyError: If no session is registered under *session_id*.
        """
        session = self._sessions.get(session_id)
        if session is None:
            raise KeyError(session_id)
        return session