# trenches/backend/src/trenches_env/session_manager.py
# Last commit (Codex): sync main snapshot for HF Space
# Commit hash: 1794757
from __future__ import annotations

import logging
import threading
from datetime import datetime, timedelta, timezone

from trenches_env.benchmark_runner import ScenarioBenchmarkRunner
from trenches_env.rl import DEFAULT_TRAINING_STAGE
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.models import (
    BenchmarkRunRequest,
    BenchmarkRunResponse,
    IngestNewsRequest,
    IngestNewsResponse,
    LiveControlRequest,
    ProviderDiagnosticsResponse,
    ReactionLogEntry,
    ScenarioSummary,
    SessionState,
    SourceMonitorReport,
    StepSessionRequest,
    StepSessionResponse,
)
from trenches_env.source_ingestion import SourceHarvester


class SessionManager:
    """Registry of sessions backed by a ``FogOfWarDiplomacyEnv``.

    Each session operation looks up the cached ``SessionState`` by id, hands it
    to the environment, and stores whatever state the environment returns.
    Access to the session table is serialized with a re-entrant lock. An
    optional daemon thread periodically auto-steps sessions that have live mode
    and auto-step enabled.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, env: FogOfWarDiplomacyEnv | None = None) -> None:
        """Wrap *env*, or a default-constructed environment when none is given."""
        self.env = env or FogOfWarDiplomacyEnv()
        self._sessions: dict[str, SessionState] = {}
        # Re-entrant so a method holding the lock may call another locked method.
        self._lock = threading.RLock()
        self._background_tick_seconds = 1.0
        self._background_stop = threading.Event()
        self._background_thread: threading.Thread | None = None

    def start_background_runner(self, tick_interval_seconds: float | None = None) -> None:
        """Start the daemon thread that ticks live sessions.

        Args:
            tick_interval_seconds: Optional new delay between ticks; values
                below 0.05 seconds are raised to 0.05. The interval is updated
                even when a runner is already alive.

        Does nothing else if a runner thread is already alive.
        """
        with self._lock:
            if tick_interval_seconds is not None:
                self._background_tick_seconds = max(0.05, tick_interval_seconds)
            running = self._background_thread
            if running is not None and running.is_alive():
                return
            # Each loop gets its own stop event. Re-using (clearing) a shared
            # event would revive an older loop that outlived the join timeout
            # in stop_background_runner(), leaving two loops ticking at once.
            stop_event = threading.Event()
            self._background_stop = stop_event
            self._background_thread = threading.Thread(
                target=self._run_background_loop,
                args=(stop_event,),
                name="trenches-session-manager",
                daemon=True,
            )
            self._background_thread.start()

    def stop_background_runner(self) -> None:
        """Signal the background loop to stop and wait briefly for it to exit.

        The wait is bounded by ``max(1.0, 2 * tick interval)`` seconds. The lock
        is deliberately not held while joining, because the loop itself needs
        the lock to finish an in-flight tick.
        """
        self._background_stop.set()
        thread = self._background_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=max(1.0, self._background_tick_seconds * 2.0))
            if thread.is_alive():
                # Its stop event stays set, so it exits after the current tick.
                self._log.warning(
                    "Background runner %s did not stop within the join timeout.",
                    thread.name,
                )
        self._background_thread = None

    def shutdown(self) -> None:
        """Stop the background runner, then shut down the environment."""
        self.stop_background_runner()
        self.env.shutdown()

    def create_session(
        self,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: str = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> SessionState:
        """Create a session through the environment and register it by its id.

        All arguments are forwarded unchanged to ``env.create_session``.

        Returns:
            The newly created session state.
        """
        with self._lock:
            session = self.env.create_session(
                seed=seed,
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=max_turns,
                scenario_id=scenario_id,
                replay_id=replay_id,
                replay_start_index=replay_start_index,
            )
            self._sessions[session.session_id] = session
            return session

    def reset_session(
        self,
        session_id: str,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: str = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> SessionState:
        """Reset an existing session through the environment and store the result.

        All arguments are forwarded unchanged to ``env.reset_session``.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            self._require_session(session_id)
            session = self.env.reset_session(
                session_id=session_id,
                seed=seed,
                training_agent=training_agent,
                training_stage=training_stage,
                max_turns=max_turns,
                scenario_id=scenario_id,
                replay_id=replay_id,
                replay_start_index=replay_start_index,
            )
            self._sessions[session_id] = session
            return session

    def get_session(self, session_id: str) -> SessionState:
        """Return the session after bringing it up to date.

        Sessions with live mode and auto-step both enabled go through
        ``env.maybe_auto_step_live_session``; all others only get
        ``env.refresh_session_sources``. The returned state replaces the
        cached one, so this read has side effects.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            session = self._require_session(session_id)
            if session.live.enabled and session.live.auto_step:
                refreshed = self.env.maybe_auto_step_live_session(session)
            else:
                refreshed = self.env.refresh_session_sources(session)
            self._sessions[session_id] = refreshed
            return refreshed

    def set_live_mode(self, session_id: str, request: LiveControlRequest) -> SessionState:
        """Apply *request* via ``env.configure_live_session`` and store the result.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            updated = self.env.configure_live_session(current, request)
            self._sessions[session_id] = updated
            return updated

    def step_session(self, session_id: str, request: StepSessionRequest) -> StepSessionResponse:
        """Advance the session via ``env.step_session`` and store the new state.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            result = self.env.step_session(current, request)
            self._sessions[session_id] = result.session
            return result

    def ingest_news(self, session_id: str, request: IngestNewsRequest) -> IngestNewsResponse:
        """Feed external signals into a session and step it once.

        The session's sources are refreshed, policy actions are resolved for
        the signals (restricted to ``request.agent_ids`` when that is
        non-empty), and the session is stepped with those actions and signals.

        Returns:
            The stepped session, its oversight result, the done flag, and the
            last reaction-log entry (``None`` when the log is empty).

        Raises:
            ValueError: If the request carries no signals. This is checked
                before the session lookup.
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            if not request.signals:
                raise ValueError("At least one external signal is required.")
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            actions = self.env.resolve_policy_actions(
                refreshed,
                request.signals,
                # An empty agent list is passed on as None.
                agent_ids=request.agent_ids or None,
            )
            result = self.env.step_session(
                refreshed,
                StepSessionRequest(actions=actions, external_signals=request.signals),
            )
            self._sessions[session_id] = result.session
            log = result.session.reaction_log
            reaction: ReactionLogEntry | None = log[-1] if log else None
            return IngestNewsResponse(
                session=result.session,
                oversight=result.oversight,
                reaction=reaction,
                done=result.done,
            )

    def refresh_session_sources(self, session_id: str, force: bool = False) -> SessionState:
        """Refresh the session via ``env.refresh_session_sources`` and store it.

        Args:
            session_id: Id of a registered session.
            force: Forwarded unchanged to the environment.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current, force=force)
            self._sessions[session_id] = refreshed
            return refreshed

    def source_monitor(self, session_id: str) -> SourceMonitorReport:
        """Refresh the session's sources, store it, and return its monitor report.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            self._sessions[session_id] = refreshed
            return self.env.source_monitor(refreshed)

    def reaction_log(self, session_id: str) -> list[ReactionLogEntry]:
        """Return deep copies of the session's reaction-log entries.

        Copies are returned so callers cannot mutate the cached session. Unlike
        the other readers, this does not refresh the session first.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            return [entry.model_copy(deep=True) for entry in current.reaction_log]

    def provider_diagnostics(self, session_id: str) -> ProviderDiagnosticsResponse:
        """Refresh the session's sources, store it, and return provider diagnostics.

        Raises:
            KeyError: If *session_id* is not registered.
        """
        with self._lock:
            current = self._require_session(session_id)
            refreshed = self.env.refresh_session_sources(current)
            self._sessions[session_id] = refreshed
            return self.env.provider_diagnostics(refreshed)

    def list_scenarios(self) -> list[ScenarioSummary]:
        """Return a ``ScenarioSummary`` for every scenario the environment lists."""
        return [
            ScenarioSummary(
                id=scenario.id,
                name=scenario.name,
                description=scenario.description,
                tags=list(scenario.tags),
                benchmark_turns=scenario.benchmark_turns,
                benchmark_enabled=scenario.benchmark_enabled,
            )
            for scenario in self.env.list_scenarios()
        ]

    def run_benchmark(self, request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        """Run *request* through a ``ScenarioBenchmarkRunner``.

        The runner is handed a factory that builds a fresh environment whose
        ``SourceHarvester`` is created with ``auto_start=False``; this
        manager's own ``env`` and sessions are not passed to it.
        """
        runner = ScenarioBenchmarkRunner(
            env_factory=lambda: FogOfWarDiplomacyEnv(source_harvester=SourceHarvester(auto_start=False))
        )
        return runner.run(request)

    def _run_background_loop(self, stop_event: threading.Event | None = None) -> None:
        """Tick live sessions repeatedly until the stop event is set.

        Args:
            stop_event: Event that ends this particular loop. Defaults to the
                manager's current ``_background_stop`` for direct callers.
        """
        stop = self._background_stop if stop_event is None else stop_event
        while not stop.is_set():
            try:
                self._tick_live_sessions()
            except Exception:
                # Best effort: keep the loop alive, but leave a trace of the
                # failure instead of discarding it silently.
                self._log.exception("Live session tick failed; retrying on the next interval.")
            # Doubles as the sleep; returns early once a stop is requested.
            stop.wait(self._background_tick_seconds)

    def _tick_live_sessions(self) -> None:
        """Auto-step every session that is due for a live tick.

        A failure in one session is logged and does not prevent the remaining
        sessions from being processed in the same pass.
        """
        now = datetime.now(timezone.utc)
        with self._lock:
            # Iterate over a snapshot because entries are reassigned in the loop.
            for session_id, session in list(self._sessions.items()):
                try:
                    if not self._session_needs_live_tick(session, now):
                        continue
                    self._sessions[session_id] = self.env.maybe_auto_step_live_session(session)
                except Exception:
                    self._log.exception("Auto-step failed for live session %s.", session_id)

    @staticmethod
    def _session_needs_live_tick(session: SessionState, now: datetime) -> bool:
        """Return whether *session* is due for an automatic live step at *now*.

        A session qualifies when live mode and auto-step are both enabled and
        either it has never been auto-stepped or at least its poll interval
        (treated as no shorter than 1000 ms) has elapsed since the last one.
        """
        if not session.live.enabled or not session.live.auto_step:
            return False
        if session.live.last_auto_step_at is None:
            return True
        interval = timedelta(milliseconds=max(session.live.poll_interval_ms, 1_000))
        return now - session.live.last_auto_step_at >= interval

    def _require_session(self, session_id: str) -> SessionState:
        """Return the cached session for *session_id*.

        Raises:
            KeyError: If no session with that id is registered.
        """
        session = self._sessions.get(session_id)
        if session is None:
            raise KeyError(session_id)
        return session