Spaces:
Paused
Paused
| from __future__ import annotations | |
| from typing import Any, Callable | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel, Field | |
| from trenches_env.agents import AGENT_IDS | |
| from trenches_env.env import FogOfWarDiplomacyEnv | |
| from trenches_env.models import ( | |
| AgentAction, | |
| AgentObservation, | |
| ExternalSignal, | |
| HistoricalEvent, | |
| HistoricalReplayState, | |
| OversightIntervention, | |
| Prediction, | |
| PredictionAssessment, | |
| RewardBreakdown, | |
| SessionState, | |
| StepSessionRequest, | |
| ) | |
| from trenches_env.rl import DEFAULT_MAX_TURNS, DEFAULT_TRAINING_STAGE, TrainingStage | |
| from trenches_env.session_manager import SessionManager | |
| from trenches_env.source_ingestion import SourceHarvester | |
try:
    from openenv.core.env_server import create_app as create_openenv_app
    from openenv.core.env_server.interfaces import Environment as OpenEnvEnvironmentBase
    from openenv.core.env_server.types import Action as OpenEnvActionBase
    from openenv.core.env_server.types import Observation as OpenEnvObservationBase
    from openenv.core.env_server.types import State as OpenEnvStateBase

    OPENENV_CORE_AVAILABLE = True
except ImportError:
    # openenv-core is an optional dependency.  When it is absent we install
    # light-weight shims under the same names so the rest of this module can
    # still be imported, subclassed, and type-checked.
    OPENENV_CORE_AVAILABLE = False

    class OpenEnvActionBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

    class OpenEnvObservationBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

        # Mirrors the two fields the real OpenEnv Observation base exposes.
        reward: float = 0.0
        done: bool = False

    class OpenEnvStateBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

        episode_id: str = ""
        step_count: int = 0

    class OpenEnvEnvironmentBase:
        """Fallback shim used when openenv-core is not installed."""

        SUPPORTS_CONCURRENT_SESSIONS = True

        def state(self) -> OpenEnvStateBase:
            raise NotImplementedError

        def reset(self, **kwargs: Any) -> OpenEnvObservationBase:
            raise NotImplementedError

        def step(self, action: OpenEnvActionBase, **kwargs: Any) -> OpenEnvObservationBase:
            raise NotImplementedError

        def close(self) -> None:
            return None

    def create_openenv_app(*args: Any, **kwargs: Any) -> FastAPI:  # type: ignore[misc]
        """Shim for openenv's app factory; fails loudly instead of silently serving."""
        raise RuntimeError("openenv-core is not installed")
class TrenchesOpenEnvAction(OpenEnvActionBase):
    """Structured joint action payload for a single OpenEnv step.

    `action` is convenient for single-policy training.
    `actions` enables explicit joint rollouts or self-play.
    """

    # Single-agent shortcut: merged into the joint map under `action.actor`.
    action: AgentAction | None = None
    # Explicit per-agent joint actions, keyed by agent id.
    actions: dict[str, AgentAction] = Field(default_factory=dict)
    # Single-agent prediction shortcut; merged under `prediction.agent_id`.
    prediction: Prediction | None = None
    predictions: dict[str, Prediction] = Field(default_factory=dict)
    # Exogenous events to inject into this step.
    external_signals: list[ExternalSignal] = Field(default_factory=list)
    # For agents absent from the joint map: first preference is the env's
    # policy autofill (needs an active session), otherwise a "hold" action.
    autofill_missing_with_policy: bool = True
    autofill_missing_with_hold: bool = True
class TrenchesOpenEnvObservation(OpenEnvObservationBase):
    """Per-step observation returned to the OpenEnv client.

    Inherits ``reward`` and ``done`` from the OpenEnv observation base type.
    """

    session_id: str
    # Agent id whose perspective `agent_observation` is rendered from.
    training_agent: str
    turn: int
    training_stage: TrainingStage = DEFAULT_TRAINING_STAGE
    # Fog-of-war view for the training agent.
    agent_observation: AgentObservation
    # All agents' views; only populated when joint observations were requested at reset.
    joint_observations: dict[str, AgentObservation] = Field(default_factory=dict)
    reward_breakdown: RewardBreakdown = Field(default_factory=RewardBreakdown)
    oversight: OversightIntervention = Field(default_factory=OversightIntervention)
    # Replay state with the ground-truth timeline redacted while a replay is active.
    historical_replay: HistoricalReplayState = Field(default_factory=HistoricalReplayState)
    revealed_event: HistoricalEvent | None = None
    prediction_assessments: dict[str, PredictionAssessment] = Field(default_factory=dict)
    # "tension_threshold" | "max_turns" | "replay_complete", or None while running.
    done_reason: str | None = None
class TrenchesOpenEnvState(OpenEnvStateBase):
    """Server-side episode snapshot for the OpenEnv state endpoint.

    Inherits ``episode_id`` and ``step_count`` from the OpenEnv state base type.
    """

    session_id: str = ""
    training_agent: str = "us"
    training_stage: TrainingStage = DEFAULT_TRAINING_STAGE
    max_turns: int = DEFAULT_MAX_TURNS
    live_enabled: bool = False
    # Per-agent reward breakdowns copied from the session.
    reward_breakdowns: dict[str, RewardBreakdown] = Field(default_factory=dict)
    last_oversight: OversightIntervention = Field(default_factory=OversightIntervention)
    # Full session payload; None before the first reset.
    session: SessionState | None = None
class TrenchesOpenEnvEnvironment(OpenEnvEnvironmentBase):
    """Native OpenEnv environment boundary for the crisis simulator.

    Wraps a :class:`FogOfWarDiplomacyEnv` session behind the OpenEnv
    ``reset``/``step``/``state`` protocol, training one agent while the
    remaining agents are filled in by policy or "hold" actions.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, env: FogOfWarDiplomacyEnv | None = None) -> None:
        # auto_start=False: training environments must not spin up live
        # source harvesting as a side effect of construction.
        self._env = env or FogOfWarDiplomacyEnv(source_harvester=SourceHarvester(auto_start=False))
        self._session: SessionState | None = None
        self._training_agent = "us"
        self._include_joint_observations = False
        self._last_oversight = OversightIntervention()

    def state(self) -> TrenchesOpenEnvState:
        """Return a snapshot of the current episode (default values before reset)."""
        if self._session is None:
            return TrenchesOpenEnvState(
                episode_id="",
                step_count=0,
                session_id="",
                training_agent=self._training_agent,
                training_stage=DEFAULT_TRAINING_STAGE,
                max_turns=DEFAULT_MAX_TURNS,
            )
        return TrenchesOpenEnvState(
            episode_id=self._session.session_id,
            step_count=self._session.world.turn,
            session_id=self._session.session_id,
            training_agent=self._training_agent,
            training_stage=self._session.episode.training_stage,
            max_turns=self._session.episode.max_turns,
            live_enabled=self._session.live.enabled,
            reward_breakdowns=self._session.rewards,
            last_oversight=self._last_oversight,
            session=self._session,
        )

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        training_agent: str = "us",
        training_stage: TrainingStage = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
        include_joint_observations: bool = False,
        **_: Any,
    ) -> TrenchesOpenEnvObservation:
        """Start a new session and return the first observation.

        Raises:
            ValueError: if ``training_agent`` is not a known agent id.
        """
        self._training_agent = self._validate_training_agent(training_agent)
        self._include_joint_observations = include_joint_observations
        self._last_oversight = OversightIntervention()
        self._session = self._env.create_session(
            seed=seed,
            session_id=episode_id,
            training_agent=self._training_agent,
            training_stage=training_stage,
            max_turns=max_turns,
            scenario_id=scenario_id,
            replay_id=replay_id,
            replay_start_index=replay_start_index,
        )
        return self._build_observation(self._session)

    def step(self, action: TrenchesOpenEnvAction, **_: Any) -> TrenchesOpenEnvObservation:
        """Advance the session by one joint turn.

        Raises:
            RuntimeError: if called before :meth:`reset`.
        """
        if self._session is None:
            raise RuntimeError("reset() must be called before step().")
        joint_actions = self._resolve_joint_actions(action)
        result = self._env.step_session(
            self._session,
            StepSessionRequest(
                actions=joint_actions,
                predictions=self._resolve_joint_predictions(action),
                external_signals=action.external_signals,
            ),
        )
        self._session = result.session
        self._last_oversight = result.oversight
        return self._build_observation(result.session)

    def close(self) -> None:
        """Release environment resources."""
        self._env.shutdown()

    def _build_observation(self, session: SessionState) -> TrenchesOpenEnvObservation:
        """Render the training agent's observation, deriving reward and done status."""
        reward_breakdown = session.rewards.get(self._training_agent, RewardBreakdown())
        # Episode ends on runaway tension, the turn cap, or a fully revealed replay.
        done_reason: str | None = None
        if session.world.tension_level >= 95.0:
            done_reason = "tension_threshold"
        elif session.world.turn >= session.episode.max_turns:
            done_reason = "max_turns"
        elif (
            session.historical_replay.enabled
            and session.historical_replay.current_event_index >= len(session.historical_replay.ground_truth_timeline) - 1
        ):
            done_reason = "replay_complete"
        recent_trace = session.recent_traces[-1] if session.recent_traces else None
        return TrenchesOpenEnvObservation(
            session_id=session.session_id,
            training_agent=self._training_agent,
            turn=session.world.turn,
            training_stage=session.episode.training_stage,
            agent_observation=session.observations[self._training_agent],
            joint_observations=session.observations if self._include_joint_observations else {},
            reward=reward_breakdown.total,
            done=done_reason is not None,
            reward_breakdown=reward_breakdown,
            oversight=self._last_oversight,
            historical_replay=self._public_historical_replay(session.historical_replay),
            revealed_event=recent_trace.revealed_event if recent_trace is not None else session.historical_replay.last_revealed_event,
            prediction_assessments=recent_trace.prediction_assessments if recent_trace is not None else {},
            done_reason=done_reason,
        )

    def _resolve_joint_actions(self, action: TrenchesOpenEnvAction) -> dict[str, AgentAction]:
        """Merge the single and joint action payloads into a full joint action.

        Agents missing from the payload are filled by the env's policy (when
        enabled and a session exists) or, failing that, with "hold" actions.

        Raises:
            ValueError: if the payload references an unknown agent id.
        """
        joint_actions = dict(action.actions)
        if action.action is not None:
            # The single-action shortcut wins over a same-actor entry in `actions`.
            joint_actions[action.action.actor] = action.action
        unknown_agents = sorted(set(joint_actions) - set(AGENT_IDS))
        if unknown_agents:
            raise ValueError(f"Unknown agents in joint action: {', '.join(unknown_agents)}")
        missing_agents = [agent_id for agent_id in AGENT_IDS if agent_id not in joint_actions]
        if missing_agents and action.autofill_missing_with_policy and self._session is not None:
            joint_actions = self._env.resolve_policy_actions(
                self._session,
                action.external_signals,
                preset_actions=joint_actions,
                agent_ids=missing_agents,
            )
        elif action.autofill_missing_with_hold:
            for agent_id in missing_agents:
                joint_actions[agent_id] = AgentAction(
                    actor=agent_id,
                    type="hold",
                    summary=f"Auto-filled hold for {agent_id} in the OpenEnv step.",
                )
        return joint_actions

    def _resolve_joint_predictions(self, action: TrenchesOpenEnvAction) -> dict[str, Prediction]:
        """Merge the single and joint prediction payloads, keyed by predicting agent."""
        joint_predictions = dict(action.predictions)
        if action.prediction is not None:
            joint_predictions[action.prediction.agent_id] = action.prediction
        return joint_predictions

    # NOTE: the two helpers below are invoked through `self.` but take no
    # instance state; they were previously declared without `self` and without
    # @staticmethod, which made every bound call raise TypeError.
    @staticmethod
    def _validate_training_agent(training_agent: str) -> str:
        """Return *training_agent* unchanged, raising for unknown agent ids."""
        if training_agent not in AGENT_IDS:
            raise ValueError(f"Unknown training_agent: {training_agent}")
        return training_agent

    @staticmethod
    def _public_historical_replay(historical_replay: HistoricalReplayState) -> HistoricalReplayState:
        """Deep-copy replay state, redacting the ground truth while a replay is active."""
        if not historical_replay.enabled:
            return historical_replay.model_copy(deep=True)
        return historical_replay.model_copy(
            update={
                "ground_truth_timeline": [],
            },
            deep=True,
        )
def create_openenv_fastapi_app(
    env_factory: Callable[[], TrenchesOpenEnvEnvironment] | None = None,
) -> FastAPI | None:
    """Build the FastAPI app exposing the native OpenEnv HTTP surface.

    Returns None when openenv-core is not installed; otherwise delegates to
    openenv's own app factory with the trenches action/observation models.
    """
    if not OPENENV_CORE_AVAILABLE:
        return None
    factory = TrenchesOpenEnvEnvironment if env_factory is None else env_factory
    return create_openenv_app(
        factory,
        TrenchesOpenEnvAction,
        TrenchesOpenEnvObservation,
        env_name="trenches",
    )
class OpenEnvAdapter:
    """Legacy tuple-based runtime kept for the existing frontend/session endpoints."""

    def __init__(self, session_manager: SessionManager | None = None) -> None:
        self.session_manager = session_manager or SessionManager()
        self._current_session_id: str | None = None

    def reset(
        self,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: TrainingStage = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Create a fresh session and return ``(observations, info)``."""
        session = self.session_manager.create_session(
            seed=seed,
            training_agent=training_agent,
            training_stage=training_stage,
            max_turns=max_turns,
            scenario_id=scenario_id,
            replay_id=replay_id,
            replay_start_index=replay_start_index,
        )
        self._current_session_id = session.session_id
        return session.observations, self._build_info(session=session, oversight=OversightIntervention())

    def step(
        self,
        actions: dict[str, Any],
        predictions: dict[str, Prediction] | None = None,
        external_signals: list[ExternalSignal] | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any], bool, bool, dict[str, Any]]:
        """Advance the current session one turn, Gymnasium-style.

        Returns ``(observations, rewards, terminated, truncated, info)``.
        Implicitly resets with defaults when no session exists yet.
        """
        if self._current_session_id is None:
            self.reset()
        result = self.session_manager.step_session(
            self._current_session_id,
            StepSessionRequest(
                actions=actions,
                predictions=predictions or {},
                external_signals=external_signals or [],
            ),
        )
        session = result.session
        # Terminate on runaway tension; truncate at the turn cap or once a
        # historical replay has revealed its final ground-truth event.
        terminated = session.world.tension_level >= 95.0
        truncated = session.world.turn >= session.episode.max_turns or (
            session.historical_replay.enabled
            and session.historical_replay.current_event_index >= len(session.historical_replay.ground_truth_timeline) - 1
        )
        return (
            session.observations,
            session.rewards,
            terminated,
            truncated,
            self._build_info(session=session, oversight=result.oversight),
        )

    def state(self) -> SessionState | None:
        """Return the live session state, or None before the first reset."""
        if self._current_session_id is None:
            return None
        return self.session_manager.get_session(self._current_session_id)

    # NOTE: previously declared without `self` and without @staticmethod, so
    # the bound calls in reset()/step() raised TypeError at runtime.
    @staticmethod
    def _build_info(session: SessionState, oversight: OversightIntervention) -> dict[str, Any]:
        """Serialize auxiliary session detail into the Gym-style ``info`` dict."""
        return {
            "session_id": session.session_id,
            "turn": session.world.turn,
            "agent_ids": list(AGENT_IDS),
            "world": session.world.model_dump(mode="json"),
            "belief_state": {agent_id: state.model_dump(mode="json") for agent_id, state in session.belief_state.items()},
            "episode": session.episode.model_dump(mode="json"),
            "oversight": oversight.model_dump(mode="json"),
            "live": session.live.model_dump(mode="json"),
            "historical_replay": session.historical_replay.model_dump(mode="json"),
            "prediction_assessments": [assessment.model_dump(mode="json") for assessment in session.prediction_assessments[-5:]],
        }