| | from __future__ import annotations |
| |
|
| | from typing import Any, Callable |
| |
|
| | from fastapi import FastAPI |
| | from pydantic import BaseModel, Field |
| |
|
| | from trenches_env.agents import AGENT_IDS |
| | from trenches_env.env import FogOfWarDiplomacyEnv |
| | from trenches_env.models import ( |
| | AgentAction, |
| | AgentObservation, |
| | ExternalSignal, |
| | HistoricalEvent, |
| | HistoricalReplayState, |
| | OversightIntervention, |
| | Prediction, |
| | PredictionAssessment, |
| | RewardBreakdown, |
| | SessionState, |
| | StepSessionRequest, |
| | ) |
| | from trenches_env.rl import DEFAULT_MAX_TURNS, DEFAULT_TRAINING_STAGE, TrainingStage |
| | from trenches_env.session_manager import SessionManager |
| | from trenches_env.source_ingestion import SourceHarvester |
| |
|
# openenv-core is an optional dependency. When it is installed we reuse its
# native Action/Observation/State/Environment interfaces and its app factory;
# otherwise minimal local shims are defined so this module still imports and
# the rest of the file type-checks against the same names.
try:
    from openenv.core.env_server import create_app as create_openenv_app
    from openenv.core.env_server.interfaces import Environment as OpenEnvEnvironmentBase
    from openenv.core.env_server.types import Action as OpenEnvActionBase
    from openenv.core.env_server.types import Observation as OpenEnvObservationBase
    from openenv.core.env_server.types import State as OpenEnvStateBase

    OPENENV_CORE_AVAILABLE = True
except ImportError:
    OPENENV_CORE_AVAILABLE = False

    class OpenEnvActionBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

    class OpenEnvObservationBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

        # Mirrors the reward/done fields the subclasses below rely on.
        reward: float = 0.0
        done: bool = False

    class OpenEnvStateBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

        episode_id: str = ""
        step_count: int = 0

    class OpenEnvEnvironmentBase:
        """Fallback shim used when openenv-core is not installed."""

        SUPPORTS_CONCURRENT_SESSIONS = True

        @property
        def state(self) -> OpenEnvStateBase:
            raise NotImplementedError

        def reset(self, **kwargs: Any) -> OpenEnvObservationBase:
            raise NotImplementedError

        def step(self, action: OpenEnvActionBase, **kwargs: Any) -> OpenEnvObservationBase:
            raise NotImplementedError

        def close(self) -> None:
            return None

    # Keeps the name importable; callers must check OPENENV_CORE_AVAILABLE
    # (as create_openenv_fastapi_app does) before invoking it.
    def create_openenv_app(*args: Any, **kwargs: Any) -> FastAPI:
        raise RuntimeError("openenv-core is not installed")
| |
|
| |
|
class TrenchesOpenEnvAction(OpenEnvActionBase):
    """Structured joint action payload for a single OpenEnv step.

    `action` is convenient for single-policy training.
    `actions` enables explicit joint rollouts or self-play.
    """

    # Single-agent convenience; merged into the joint map under `action.actor`.
    action: AgentAction | None = None
    # Explicit per-agent joint actions, keyed by agent id.
    actions: dict[str, AgentAction] = Field(default_factory=dict)
    # Single prediction; merged under `prediction.agent_id`.
    prediction: Prediction | None = None
    predictions: dict[str, Prediction] = Field(default_factory=dict)
    # Out-of-band signals forwarded to the environment with this step.
    external_signals: list[ExternalSignal] = Field(default_factory=list)
    # For agents missing from the joint map: prefer the env's policy autofill
    # when enabled, otherwise fall back to auto-filled "hold" actions.
    autofill_missing_with_policy: bool = True
    autofill_missing_with_hold: bool = True
|
| |
|
class TrenchesOpenEnvObservation(OpenEnvObservationBase):
    """Per-step observation returned to the OpenEnv client.

    Inherits `reward` and `done` from the base observation type.
    """

    session_id: str
    training_agent: str
    turn: int
    training_stage: TrainingStage = DEFAULT_TRAINING_STAGE
    # Fog-of-war view for the training agent only.
    agent_observation: AgentObservation
    # All agents' observations; populated only when reset() was called with
    # include_joint_observations=True, otherwise empty.
    joint_observations: dict[str, AgentObservation] = Field(default_factory=dict)
    reward_breakdown: RewardBreakdown = Field(default_factory=RewardBreakdown)
    oversight: OversightIntervention = Field(default_factory=OversightIntervention)
    # Replay state with the ground-truth timeline stripped while replay is active.
    historical_replay: HistoricalReplayState = Field(default_factory=HistoricalReplayState)
    revealed_event: HistoricalEvent | None = None
    prediction_assessments: dict[str, PredictionAssessment] = Field(default_factory=dict)
    # One of "tension_threshold", "max_turns", "replay_complete"; None while running.
    done_reason: str | None = None
|
| |
|
class TrenchesOpenEnvState(OpenEnvStateBase):
    """Server-side episode state snapshot.

    Inherits `episode_id` and `step_count` from the base state type.
    """

    session_id: str = ""
    training_agent: str = "us"
    training_stage: TrainingStage = DEFAULT_TRAINING_STAGE
    max_turns: int = DEFAULT_MAX_TURNS
    live_enabled: bool = False
    # Per-agent reward breakdowns taken from the session's rewards map.
    reward_breakdowns: dict[str, RewardBreakdown] = Field(default_factory=dict)
    last_oversight: OversightIntervention = Field(default_factory=OversightIntervention)
    # Full session snapshot; None until reset() has created a session.
    session: SessionState | None = None
|
| |
|
class TrenchesOpenEnvEnvironment(OpenEnvEnvironmentBase):
    """Native OpenEnv environment boundary for the crisis simulator."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, env: FogOfWarDiplomacyEnv | None = None) -> None:
        """Wrap *env*, constructing a default environment when none is supplied."""
        self._env = (
            FogOfWarDiplomacyEnv(source_harvester=SourceHarvester(auto_start=False))
            if not env
            else env
        )
        self._session: SessionState | None = None
        self._training_agent = "us"
        self._include_joint_observations = False
        self._last_oversight = OversightIntervention()

    @property
    def state(self) -> TrenchesOpenEnvState:
        """Snapshot the current episode; a blank state is returned before reset()."""
        session = self._session
        if session is None:
            return TrenchesOpenEnvState(
                episode_id="",
                step_count=0,
                session_id="",
                training_agent=self._training_agent,
                training_stage=DEFAULT_TRAINING_STAGE,
                max_turns=DEFAULT_MAX_TURNS,
            )
        return TrenchesOpenEnvState(
            episode_id=session.session_id,
            step_count=session.world.turn,
            session_id=session.session_id,
            training_agent=self._training_agent,
            training_stage=session.episode.training_stage,
            max_turns=session.episode.max_turns,
            live_enabled=session.live.enabled,
            reward_breakdowns=session.rewards,
            last_oversight=self._last_oversight,
            session=session,
        )

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        training_agent: str = "us",
        training_stage: TrainingStage = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
        include_joint_observations: bool = False,
        **_: Any,
    ) -> TrenchesOpenEnvObservation:
        """Start a fresh session and return its first observation.

        Raises ValueError when *training_agent* is not a known agent id;
        validation happens before any session state is created.
        """
        self._training_agent = self._validate_training_agent(training_agent)
        self._include_joint_observations = include_joint_observations
        self._last_oversight = OversightIntervention()
        session = self._env.create_session(
            seed=seed,
            session_id=episode_id,
            training_agent=self._training_agent,
            training_stage=training_stage,
            max_turns=max_turns,
            scenario_id=scenario_id,
            replay_id=replay_id,
            replay_start_index=replay_start_index,
        )
        self._session = session
        return self._build_observation(session)

    def step(self, action: TrenchesOpenEnvAction, **_: Any) -> TrenchesOpenEnvObservation:
        """Advance the session one turn with the (possibly autofilled) joint action."""
        session = self._session
        if session is None:
            raise RuntimeError("reset() must be called before step().")

        request = StepSessionRequest(
            actions=self._resolve_joint_actions(action),
            predictions=self._resolve_joint_predictions(action),
            external_signals=action.external_signals,
        )
        outcome = self._env.step_session(session, request)
        self._session = outcome.session
        self._last_oversight = outcome.oversight
        return self._build_observation(outcome.session)

    def close(self) -> None:
        """Release the wrapped environment's resources."""
        self._env.shutdown()

    def _build_observation(self, session: SessionState) -> TrenchesOpenEnvObservation:
        """Translate a session snapshot into the OpenEnv observation schema."""
        breakdown = session.rewards.get(self._training_agent, RewardBreakdown())
        replay = session.historical_replay

        # Termination causes, checked in priority order.
        done_reason: str | None = None
        if session.world.tension_level >= 95.0:
            done_reason = "tension_threshold"
        elif session.world.turn >= session.episode.max_turns:
            done_reason = "max_turns"
        elif replay.enabled and replay.current_event_index >= len(replay.ground_truth_timeline) - 1:
            done_reason = "replay_complete"

        last_trace = session.recent_traces[-1] if session.recent_traces else None
        if last_trace is None:
            revealed = replay.last_revealed_event
            assessments: dict[str, PredictionAssessment] = {}
        else:
            revealed = last_trace.revealed_event
            assessments = last_trace.prediction_assessments

        return TrenchesOpenEnvObservation(
            session_id=session.session_id,
            training_agent=self._training_agent,
            turn=session.world.turn,
            training_stage=session.episode.training_stage,
            agent_observation=session.observations[self._training_agent],
            joint_observations=session.observations if self._include_joint_observations else {},
            reward=breakdown.total,
            done=done_reason is not None,
            reward_breakdown=breakdown,
            oversight=self._last_oversight,
            historical_replay=self._public_historical_replay(replay),
            revealed_event=revealed,
            prediction_assessments=assessments,
            done_reason=done_reason,
        )

    def _resolve_joint_actions(self, action: TrenchesOpenEnvAction) -> dict[str, AgentAction]:
        """Merge single + joint actions, then backfill any absent agents.

        Raises ValueError for actor ids outside AGENT_IDS.
        """
        resolved = dict(action.actions)
        if action.action is not None:
            # The single-agent shortcut wins over an entry for the same actor.
            resolved[action.action.actor] = action.action

        unexpected = sorted(set(resolved) - set(AGENT_IDS))
        if unexpected:
            raise ValueError(f"Unknown agents in joint action: {', '.join(unexpected)}")

        absent = [agent_id for agent_id in AGENT_IDS if agent_id not in resolved]
        if absent and action.autofill_missing_with_policy and self._session is not None:
            # Let the environment's own policies act for the missing agents.
            resolved = self._env.resolve_policy_actions(
                self._session,
                action.external_signals,
                preset_actions=resolved,
                agent_ids=absent,
            )
        elif action.autofill_missing_with_hold:
            for agent_id in absent:
                resolved[agent_id] = AgentAction(
                    actor=agent_id,
                    type="hold",
                    summary=f"Auto-filled hold for {agent_id} in the OpenEnv step.",
                )

        return resolved

    def _resolve_joint_predictions(self, action: TrenchesOpenEnvAction) -> dict[str, Prediction]:
        """Merge the single-prediction shortcut into the joint prediction map."""
        merged = dict(action.predictions)
        single = action.prediction
        if single is not None:
            merged[single.agent_id] = single
        return merged

    @staticmethod
    def _validate_training_agent(training_agent: str) -> str:
        """Return *training_agent* unchanged, raising ValueError for unknown ids."""
        if training_agent in AGENT_IDS:
            return training_agent
        raise ValueError(f"Unknown training_agent: {training_agent}")

    @staticmethod
    def _public_historical_replay(historical_replay: HistoricalReplayState) -> HistoricalReplayState:
        """Deep-copy replay state, hiding ground truth while replay is active."""
        if historical_replay.enabled:
            # Agents must not see the full timeline mid-replay.
            return historical_replay.model_copy(
                update={"ground_truth_timeline": []},
                deep=True,
            )
        return historical_replay.model_copy(deep=True)
| |
|
| |
|
def create_openenv_fastapi_app(
    env_factory: Callable[[], TrenchesOpenEnvEnvironment] | None = None,
) -> FastAPI | None:
    """Build the native OpenEnv FastAPI app, or return None without openenv-core."""
    if OPENENV_CORE_AVAILABLE:
        return create_openenv_app(
            env_factory or TrenchesOpenEnvEnvironment,
            TrenchesOpenEnvAction,
            TrenchesOpenEnvObservation,
            env_name="trenches",
        )
    return None
| |
|
| |
|
class OpenEnvAdapter:
    """Legacy tuple-based runtime kept for the existing frontend/session endpoints."""

    def __init__(self, session_manager: SessionManager | None = None) -> None:
        """Bind a session manager, creating a default one when none is supplied."""
        if not session_manager:
            session_manager = SessionManager()
        self.session_manager = session_manager
        self._current_session_id: str | None = None

    def reset(
        self,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: TrainingStage = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Create a fresh session and return (observations, info)."""
        session = self.session_manager.create_session(
            seed=seed,
            training_agent=training_agent,
            training_stage=training_stage,
            max_turns=max_turns,
            scenario_id=scenario_id,
            replay_id=replay_id,
            replay_start_index=replay_start_index,
        )
        self._current_session_id = session.session_id
        info = self._build_info(session=session, oversight=OversightIntervention())
        return session.observations, info

    def step(
        self,
        actions: dict[str, Any],
        predictions: dict[str, Prediction] | None = None,
        external_signals: list[ExternalSignal] | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any], bool, bool, dict[str, Any]]:
        """Advance one turn, lazily creating a session when none exists yet."""
        if self._current_session_id is None:
            self.reset()

        request = StepSessionRequest(
            actions=actions,
            predictions=predictions or {},
            external_signals=external_signals or [],
        )
        result = self.session_manager.step_session(self._current_session_id, request)
        session = result.session

        terminated = session.world.tension_level >= 95.0
        replay = session.historical_replay
        replay_finished = (
            replay.enabled
            and replay.current_event_index >= len(replay.ground_truth_timeline) - 1
        )
        truncated = replay_finished or session.world.turn >= session.episode.max_turns
        info = self._build_info(session=session, oversight=result.oversight)
        return session.observations, session.rewards, terminated, truncated, info

    def state(self) -> SessionState | None:
        """Return the live session, or None when nothing has been reset yet."""
        session_id = self._current_session_id
        return None if session_id is None else self.session_manager.get_session(session_id)

    @staticmethod
    def _build_info(session: SessionState, oversight: OversightIntervention) -> dict[str, Any]:
        """Serialize session internals into the JSON-friendly info payload."""
        belief = {
            agent_id: belief_state.model_dump(mode="json")
            for agent_id, belief_state in session.belief_state.items()
        }
        recent_assessments = [
            assessment.model_dump(mode="json")
            for assessment in session.prediction_assessments[-5:]
        ]
        return {
            "session_id": session.session_id,
            "turn": session.world.turn,
            "agent_ids": list(AGENT_IDS),
            "world": session.world.model_dump(mode="json"),
            "belief_state": belief,
            "episode": session.episode.model_dump(mode="json"),
            "oversight": oversight.model_dump(mode="json"),
            "live": session.live.model_dump(mode="json"),
            "historical_replay": session.historical_replay.model_dump(mode="json"),
            "prediction_assessments": recent_assessments,
        }
| |
|