# trenches/backend/src/trenches_env/openenv_adapter.py
# Codex
# sync main snapshot for HF Space
# 1794757
from __future__ import annotations
from typing import Any, Callable
from fastapi import FastAPI
from pydantic import BaseModel, Field
from trenches_env.agents import AGENT_IDS
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.models import (
AgentAction,
AgentObservation,
ExternalSignal,
HistoricalEvent,
HistoricalReplayState,
OversightIntervention,
Prediction,
PredictionAssessment,
RewardBreakdown,
SessionState,
StepSessionRequest,
)
from trenches_env.rl import DEFAULT_MAX_TURNS, DEFAULT_TRAINING_STAGE, TrainingStage
from trenches_env.session_manager import SessionManager
from trenches_env.source_ingestion import SourceHarvester
# Use the real openenv-core primitives when the package can be imported.
# Otherwise define minimal local stand-ins with the same names so the rest of
# this module (and anything importing it) still loads; OPENENV_CORE_AVAILABLE
# records which of the two paths was taken.
try:
    from openenv.core.env_server import create_app as create_openenv_app
    from openenv.core.env_server.interfaces import Environment as OpenEnvEnvironmentBase
    from openenv.core.env_server.types import (
        Action as OpenEnvActionBase,
        Observation as OpenEnvObservationBase,
        State as OpenEnvStateBase,
    )
except ImportError:
    OPENENV_CORE_AVAILABLE = False

    class OpenEnvActionBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

    class OpenEnvObservationBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

        # Scalar reward and termination flag carried by every observation.
        reward: float = 0.0
        done: bool = False

    class OpenEnvStateBase(BaseModel):
        """Fallback shim used when openenv-core is not installed."""

        # Episode identifier and number of steps taken so far.
        episode_id: str = ""
        step_count: int = 0

    class OpenEnvEnvironmentBase:
        """Fallback shim used when openenv-core is not installed."""

        SUPPORTS_CONCURRENT_SESSIONS = True

        @property
        def state(self) -> OpenEnvStateBase:
            # Subclasses must supply their own state snapshot.
            raise NotImplementedError

        def reset(self, **kwargs: Any) -> OpenEnvObservationBase:
            # Subclasses must implement episode start.
            raise NotImplementedError

        def step(self, action: OpenEnvActionBase, **kwargs: Any) -> OpenEnvObservationBase:
            # Subclasses must implement a single transition.
            raise NotImplementedError

        def close(self) -> None:
            # Nothing to release in the shim.
            return None

    def create_openenv_app(*args: Any, **kwargs: Any) -> FastAPI:  # type: ignore[misc]
        # There is no server factory without openenv-core; fail loudly if called.
        raise RuntimeError("openenv-core is not installed")
else:
    # Every import above succeeded, so the real OpenEnv server can be built.
    OPENENV_CORE_AVAILABLE = True
class TrenchesOpenEnvAction(OpenEnvActionBase):
    """Structured joint action payload for a single OpenEnv step.

    `action` is convenient for single-policy training.
    `actions` enables explicit joint rollouts or self-play.

    Both forms may be supplied together; the environment merges `action` into
    `actions` under the action's own `actor` id, and does the same for
    `prediction` / `predictions` using the prediction's `agent_id`.
    """

    # Single-agent shortcut and the explicit per-agent mapping.
    action: AgentAction | None = None
    actions: dict[str, AgentAction] = Field(default_factory=dict)

    # Optional forecasts, in the same single / per-agent shapes as the actions.
    prediction: Prediction | None = None
    predictions: dict[str, Prediction] = Field(default_factory=dict)

    # Extra signals forwarded to the simulator together with this step.
    external_signals: list[ExternalSignal] = Field(default_factory=list)

    # How agents without an explicit action are filled in. The environment
    # tries the policy autofill first; the "hold" autofill is only consulted
    # when the policy path is not taken.
    autofill_missing_with_policy: bool = True
    autofill_missing_with_hold: bool = True
class TrenchesOpenEnvObservation(OpenEnvObservationBase):
    """Observation returned by `TrenchesOpenEnvEnvironment` after reset/step.

    Values are reported from the point of view of `training_agent`; the other
    agents' observations are only included in `joint_observations` when the
    episode was reset with `include_joint_observations=True`.
    """

    # Identity of the episode and of the agent being trained.
    session_id: str
    training_agent: str

    # Progress through the episode.
    turn: int
    training_stage: TrainingStage = DEFAULT_TRAINING_STAGE

    # What the training agent sees, plus (optionally) what every agent sees.
    agent_observation: AgentObservation
    joint_observations: dict[str, AgentObservation] = Field(default_factory=dict)

    # Reward components for the training agent and the latest oversight result.
    reward_breakdown: RewardBreakdown = Field(default_factory=RewardBreakdown)
    oversight: OversightIntervention = Field(default_factory=OversightIntervention)

    # Historical-replay bookkeeping and the outcome of the most recent step.
    historical_replay: HistoricalReplayState = Field(default_factory=HistoricalReplayState)
    revealed_event: HistoricalEvent | None = None
    prediction_assessments: dict[str, PredictionAssessment] = Field(default_factory=dict)

    # One of "tension_threshold", "max_turns", "replay_complete", or None while
    # the episode is still running.
    done_reason: str | None = None
class TrenchesOpenEnvState(OpenEnvStateBase):
    """State snapshot exposed by `TrenchesOpenEnvEnvironment.state`.

    Every field has a default, so a snapshot can be produced before the first
    `reset()` call, when no session exists yet.
    """

    # Identity of the active session ("" before the first reset).
    session_id: str = ""
    training_agent: str = "us"

    # Episode configuration.
    training_stage: TrainingStage = DEFAULT_TRAINING_STAGE
    max_turns: int = DEFAULT_MAX_TURNS
    live_enabled: bool = False

    # Latest per-agent rewards and the oversight result of the last step.
    reward_breakdowns: dict[str, RewardBreakdown] = Field(default_factory=dict)
    last_oversight: OversightIntervention = Field(default_factory=OversightIntervention)

    # Full underlying session, or None when no episode has been started.
    session: SessionState | None = None
class TrenchesOpenEnvEnvironment(OpenEnvEnvironmentBase):
    """Native OpenEnv environment boundary for the crisis simulator.

    Wraps a `FogOfWarDiplomacyEnv` behind the OpenEnv `reset` / `step` /
    `state` / `close` surface. An instance tracks one active session at a time
    and reports reward and observation for a single `training_agent`.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, env: FogOfWarDiplomacyEnv | None = None) -> None:
        """Create the adapter.

        Args:
            env: Simulator to drive. When omitted, a new one is built around a
                `SourceHarvester(auto_start=False)`.
        """
        # Let the base class initialise whatever state it owns. Without this
        # call, attributes set by the real OpenEnv base class would be missing.
        super().__init__()
        self._env = env or FogOfWarDiplomacyEnv(source_harvester=SourceHarvester(auto_start=False))
        self._session: SessionState | None = None
        self._training_agent = "us"
        self._include_joint_observations = False
        self._last_oversight = OversightIntervention()

    @property
    def state(self) -> TrenchesOpenEnvState:
        """Return a snapshot of the current episode.

        Before the first `reset()` the snapshot carries empty identifiers and
        the default training stage / turn limit.
        """
        if self._session is None:
            return TrenchesOpenEnvState(
                episode_id="",
                step_count=0,
                session_id="",
                training_agent=self._training_agent,
                training_stage=DEFAULT_TRAINING_STAGE,
                max_turns=DEFAULT_MAX_TURNS,
            )
        return TrenchesOpenEnvState(
            # The session id doubles as the OpenEnv episode id, and the world
            # turn counter as the step count.
            episode_id=self._session.session_id,
            step_count=self._session.world.turn,
            session_id=self._session.session_id,
            training_agent=self._training_agent,
            training_stage=self._session.episode.training_stage,
            max_turns=self._session.episode.max_turns,
            live_enabled=self._session.live.enabled,
            reward_breakdowns=self._session.rewards,
            last_oversight=self._last_oversight,
            session=self._session,
        )

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        training_agent: str = "us",
        training_stage: TrainingStage = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
        include_joint_observations: bool = False,
        **_: Any,
    ) -> TrenchesOpenEnvObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Forwarded to the simulator's session creation.
            episode_id: Used as the id of the new session.
            training_agent: Agent whose reward/observation is reported; must be
                one of `AGENT_IDS`.
            training_stage: Training stage for the new episode.
            max_turns: Optional turn limit forwarded to the simulator.
            scenario_id: Optional scenario selector forwarded to the simulator.
            replay_id: Optional historical replay selector.
            replay_start_index: Optional starting index within the replay.
            include_joint_observations: When true, observations for every agent
                are attached to each returned observation.
            **_: Extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: If `training_agent` is not a known agent id.
        """
        validated_agent = self._validate_training_agent(training_agent)
        # Create the session before touching any instance state. If creation
        # fails, the previous episode and its settings stay intact instead of
        # being left half-updated.
        new_session = self._env.create_session(
            seed=seed,
            session_id=episode_id,
            training_agent=validated_agent,
            training_stage=training_stage,
            max_turns=max_turns,
            scenario_id=scenario_id,
            replay_id=replay_id,
            replay_start_index=replay_start_index,
        )
        self._training_agent = validated_agent
        self._include_joint_observations = include_joint_observations
        self._last_oversight = OversightIntervention()
        self._session = new_session
        return self._build_observation(new_session)

    def step(self, action: TrenchesOpenEnvAction, **_: Any) -> TrenchesOpenEnvObservation:
        """Advance the active session by one turn.

        Args:
            action: Joint action payload; missing agents are filled in
                according to its autofill flags.
            **_: Extra keyword arguments are accepted and ignored.

        Raises:
            RuntimeError: If `reset()` has not been called yet.
            ValueError: If the payload names an agent outside `AGENT_IDS`.
        """
        if self._session is None:
            raise RuntimeError("reset() must be called before step().")
        joint_actions = self._resolve_joint_actions(action)
        result = self._env.step_session(
            self._session,
            StepSessionRequest(
                actions=joint_actions,
                predictions=self._resolve_joint_predictions(action),
                external_signals=action.external_signals,
            ),
        )
        self._session = result.session
        self._last_oversight = result.oversight
        return self._build_observation(result.session)

    def close(self) -> None:
        """Shut down the wrapped simulator."""
        self._env.shutdown()

    def _build_observation(self, session: SessionState) -> TrenchesOpenEnvObservation:
        """Project `session` into the observation seen by the training agent."""
        reward_breakdown = session.rewards.get(self._training_agent, RewardBreakdown())
        done_reason = self._resolve_done_reason(session)
        # The newest trace (if any) carries the event revealed and the
        # prediction assessments produced by the most recent step.
        recent_trace = session.recent_traces[-1] if session.recent_traces else None
        return TrenchesOpenEnvObservation(
            session_id=session.session_id,
            training_agent=self._training_agent,
            turn=session.world.turn,
            training_stage=session.episode.training_stage,
            agent_observation=session.observations[self._training_agent],
            joint_observations=session.observations if self._include_joint_observations else {},
            reward=reward_breakdown.total,
            done=done_reason is not None,
            reward_breakdown=reward_breakdown,
            oversight=self._last_oversight,
            historical_replay=self._public_historical_replay(session.historical_replay),
            revealed_event=(
                recent_trace.revealed_event
                if recent_trace is not None
                else session.historical_replay.last_revealed_event
            ),
            prediction_assessments=recent_trace.prediction_assessments if recent_trace is not None else {},
            done_reason=done_reason,
        )

    @staticmethod
    def _resolve_done_reason(session: SessionState) -> str | None:
        """Return why the episode has ended, or `None` while it is running.

        Checks are ordered: tension threshold, then turn limit, then replay
        exhaustion; the first match wins.
        """
        if session.world.tension_level >= 95.0:
            return "tension_threshold"
        if session.world.turn >= session.episode.max_turns:
            return "max_turns"
        replay = session.historical_replay
        if replay.enabled and replay.current_event_index >= len(replay.ground_truth_timeline) - 1:
            return "replay_complete"
        return None

    def _resolve_joint_actions(self, action: TrenchesOpenEnvAction) -> dict[str, AgentAction]:
        """Build the per-agent action mapping for one step.

        The single `action` (if present) is merged into `actions` under its own
        `actor` id, replacing any entry already stored for that id.

        Raises:
            ValueError: If any key is not listed in `AGENT_IDS`.
        """
        joint_actions = dict(action.actions)
        if action.action is not None:
            joint_actions[action.action.actor] = action.action

        unknown_agents = sorted(set(joint_actions) - set(AGENT_IDS))
        if unknown_agents:
            raise ValueError(f"Unknown agents in joint action: {', '.join(unknown_agents)}")

        missing_agents = [agent_id for agent_id in AGENT_IDS if agent_id not in joint_actions]
        if missing_agents and action.autofill_missing_with_policy and self._session is not None:
            # Policy autofill takes precedence: the simulator resolves actions
            # for the missing agents, given the explicit ones as presets.
            joint_actions = self._env.resolve_policy_actions(
                self._session,
                action.external_signals,
                preset_actions=joint_actions,
                agent_ids=missing_agents,
            )
        elif action.autofill_missing_with_hold:
            # Only reached when the policy branch was not taken.
            for agent_id in missing_agents:
                joint_actions[agent_id] = AgentAction(
                    actor=agent_id,
                    type="hold",
                    summary=f"Auto-filled hold for {agent_id} in the OpenEnv step.",
                )
        return joint_actions

    def _resolve_joint_predictions(self, action: TrenchesOpenEnvAction) -> dict[str, Prediction]:
        """Merge the single `prediction` into `predictions` under its `agent_id`."""
        joint_predictions = dict(action.predictions)
        if action.prediction is not None:
            joint_predictions[action.prediction.agent_id] = action.prediction
        return joint_predictions

    @staticmethod
    def _validate_training_agent(training_agent: str) -> str:
        """Return `training_agent` unchanged, or raise `ValueError` if unknown."""
        if training_agent not in AGENT_IDS:
            raise ValueError(f"Unknown training_agent: {training_agent}")
        return training_agent

    @staticmethod
    def _public_historical_replay(historical_replay: HistoricalReplayState) -> HistoricalReplayState:
        """Return a deep copy of the replay state for inclusion in an observation.

        When replay is enabled the copy's `ground_truth_timeline` is emptied,
        so the timeline itself is not part of the observation. Presumably this
        keeps upcoming events from leaking to the agent.
        """
        if not historical_replay.enabled:
            return historical_replay.model_copy(deep=True)
        return historical_replay.model_copy(
            update={
                "ground_truth_timeline": [],
            },
            deep=True,
        )
def create_openenv_fastapi_app(
    env_factory: Callable[[], TrenchesOpenEnvEnvironment] | None = None,
) -> FastAPI | None:
    """Build the OpenEnv FastAPI application for the "trenches" environment.

    Args:
        env_factory: Zero-argument callable producing environment instances.
            Defaults to the `TrenchesOpenEnvEnvironment` class itself.

    Returns:
        The application created by openenv-core, or `None` when openenv-core
        could not be imported.
    """
    if OPENENV_CORE_AVAILABLE:
        environment_factory = env_factory or TrenchesOpenEnvEnvironment
        return create_openenv_app(
            environment_factory,
            TrenchesOpenEnvAction,
            TrenchesOpenEnvObservation,
            env_name="trenches",
        )
    # Without openenv-core there is no server to build; signal that with None
    # rather than calling the fallback factory, which always raises.
    return None
class OpenEnvAdapter:
    """Legacy tuple-based runtime kept for the existing frontend/session endpoints."""

    def __init__(self, session_manager: SessionManager | None = None) -> None:
        """Bind to `session_manager`, creating a default `SessionManager` if none is given."""
        self.session_manager = session_manager or SessionManager()
        # Id of the session driven by `step()`; None until the first reset.
        self._current_session_id: str | None = None

    def reset(
        self,
        seed: int | None = None,
        training_agent: str = "us",
        training_stage: TrainingStage = DEFAULT_TRAINING_STAGE,
        max_turns: int | None = None,
        scenario_id: str | None = None,
        replay_id: str | None = None,
        replay_start_index: int | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Create a new session and make it the current one.

        Returns:
            `(observations, info)` where `info` is built with a fresh, empty
            `OversightIntervention`.
        """
        created = self.session_manager.create_session(
            seed=seed,
            training_agent=training_agent,
            training_stage=training_stage,
            max_turns=max_turns,
            scenario_id=scenario_id,
            replay_id=replay_id,
            replay_start_index=replay_start_index,
        )
        self._current_session_id = created.session_id
        info = self._build_info(session=created, oversight=OversightIntervention())
        return created.observations, info

    def step(
        self,
        actions: dict[str, Any],
        predictions: dict[str, Prediction] | None = None,
        external_signals: list[ExternalSignal] | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any], bool, bool, dict[str, Any]]:
        """Advance the current session by one turn.

        If no session exists yet, one is created first via `reset()` with its
        default arguments.

        Returns:
            `(observations, rewards, terminated, truncated, info)`.
        """
        if self._current_session_id is None:
            self.reset()

        request = StepSessionRequest(
            actions=actions,
            predictions=predictions or {},
            external_signals=external_signals or [],
        )
        outcome = self.session_manager.step_session(self._current_session_id, request)
        current = outcome.session

        # Terminated: the world tension reached the threshold.
        terminated = current.world.tension_level >= 95.0

        # Truncated: the turn limit was hit, or an enabled historical replay
        # has reached the last entry of its timeline.
        replay = current.historical_replay
        out_of_turns = current.world.turn >= current.episode.max_turns
        truncated = out_of_turns or (
            replay.enabled
            and replay.current_event_index >= len(replay.ground_truth_timeline) - 1
        )

        info = self._build_info(session=current, oversight=outcome.oversight)
        return current.observations, current.rewards, terminated, truncated, info

    def state(self) -> SessionState | None:
        """Return the current session from the manager, or `None` if none was started."""
        session_id = self._current_session_id
        if session_id is None:
            return None
        return self.session_manager.get_session(session_id)

    @staticmethod
    def _build_info(session: SessionState, oversight: OversightIntervention) -> dict[str, Any]:
        """Assemble the JSON-friendly `info` dictionary for a reset/step result."""
        beliefs = {
            agent: belief.model_dump(mode="json")
            for agent, belief in session.belief_state.items()
        }
        # Only the five most recent assessments are reported.
        latest_assessments = [
            item.model_dump(mode="json") for item in session.prediction_assessments[-5:]
        ]
        return {
            "session_id": session.session_id,
            "turn": session.world.turn,
            "agent_ids": list(AGENT_IDS),
            "world": session.world.model_dump(mode="json"),
            "belief_state": beliefs,
            "episode": session.episode.model_dump(mode="json"),
            "oversight": oversight.model_dump(mode="json"),
            "live": session.live.model_dump(mode="json"),
            "historical_replay": session.historical_replay.model_dump(mode="json"),
            "prediction_assessments": latest_assessments,
        }