Spaces:
Paused
Paused
| from __future__ import annotations | |
| from contextlib import asynccontextmanager | |
| import os | |
| from typing import Any | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from trenches_env.env import FogOfWarDiplomacyEnv | |
| from trenches_env.model_runtime import build_entity_model_bindings | |
| from trenches_env.models import ( | |
| BenchmarkRunRequest, | |
| BenchmarkRunResponse, | |
| CreateSessionRequest, | |
| IngestNewsRequest, | |
| IngestNewsResponse, | |
| LiveControlRequest, | |
| ProviderDiagnosticsResponse, | |
| ReactionLogEntry, | |
| ResetEnvRequest, | |
| ResetEnvResponse, | |
| ResetSessionRequest, | |
| ScenarioSummary, | |
| SessionState, | |
| SourceMonitorReport, | |
| StepEnvRequest, | |
| StepEnvResponse, | |
| StepSessionRequest, | |
| StepSessionResponse, | |
| ) | |
| from trenches_env.openenv_adapter import ( | |
| OPENENV_CORE_AVAILABLE, | |
| OpenEnvAdapter, | |
| TrenchesOpenEnvEnvironment, | |
| create_openenv_fastapi_app, | |
| ) | |
| from trenches_env.session_manager import SessionManager | |
| from trenches_env.source_ingestion import SourceHarvester | |
# Fallback CORS origin pattern, applied by _resolve_cors_settings() when neither
# TRENCHES_CORS_ALLOW_ORIGINS nor TRENCHES_CORS_ALLOW_ORIGIN_REGEX is configured:
# http or https on localhost / 127.0.0.1, with an optional port.
DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX = r"https?://(localhost|127\.0\.0\.1)(:\d+)?$"
| def _parse_csv_env(raw_value: str | None) -> list[str]: | |
| if not raw_value: | |
| return [] | |
| return [item.strip() for item in raw_value.split(",") if item.strip()] | |
def _resolve_cors_settings() -> dict[str, Any]:
    """Build the keyword arguments for ``CORSMiddleware`` from the environment.

    Environment variables read:
        TRENCHES_CORS_ALLOW_ORIGINS: comma-separated list of allowed origins.
            If any entry is ``*`` the wildcard configuration is returned.
        TRENCHES_CORS_ALLOW_ORIGIN_REGEX: regular expression for allowed
            origins. A blank or whitespace-only value is treated as unset.
        TRENCHES_CORS_ALLOW_CREDENTIALS: credentials are enabled unless the
            value is one of ``0``, ``false``, ``no`` or ``off``
            (case-insensitive). Ignored in the wildcard configuration.

    When neither an origin list nor a regex is configured, the local
    development pattern ``DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX`` is used.

    Returns:
        A dict with the keys ``allow_origins``, ``allow_origin_regex``,
        ``allow_credentials``, ``allow_methods`` and ``allow_headers``.
    """
    allow_origins = _parse_csv_env(os.getenv("TRENCHES_CORS_ALLOW_ORIGINS"))

    if "*" in allow_origins:
        return {
            "allow_origins": ["*"],
            "allow_origin_regex": None,
            # Browsers reject wildcard origins when credentials are enabled.
            "allow_credentials": False,
            "allow_methods": ["*"],
            "allow_headers": ["*"],
        }

    # Normalize a blank value (e.g. `VAR=` in an env file) to "unset".
    # Otherwise a whitespace-only pattern would suppress the local-dev default
    # below while matching no real origin.
    allow_origin_regex = (os.getenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX") or "").strip() or None

    if not allow_origins and allow_origin_regex is None:
        allow_origin_regex = DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX

    raw_credentials = os.getenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "true")
    allow_credentials = raw_credentials.strip().lower() not in {
        "0",
        "false",
        "no",
        "off",
    }

    return {
        "allow_origins": allow_origins,
        "allow_origin_regex": allow_origin_regex,
        "allow_credentials": allow_credentials,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
def create_app(session_manager: SessionManager | None = None) -> FastAPI:
    """Build the Trenches FastAPI application.

    Args:
        session_manager: Manager to serve sessions from. When ``None`` a
            default ``SessionManager`` is created around a
            ``FogOfWarDiplomacyEnv`` whose ``SourceHarvester`` auto-starts
            and which has source warm start enabled.

    Returns:
        The configured ``FastAPI`` instance, with CORS middleware installed
        and, when ``create_openenv_fastapi_app`` returns an app, that app
        mounted under ``/openenv``.
    """
    if session_manager is None:
        manager = SessionManager(
            env=FogOfWarDiplomacyEnv(
                source_harvester=SourceHarvester(auto_start=True),
            ).enable_source_warm_start()
        )
    else:
        manager = session_manager

    # FastAPI's ``lifespan`` expects a callable returning an async context
    # manager, so the generator must be wrapped with ``asynccontextmanager``.
    @asynccontextmanager
    async def lifespan(_: FastAPI):
        """Start the manager's background runner on startup; shut it down on exit."""
        try:
            manager.start_background_runner()
            yield
        finally:
            manager.shutdown()

    app = FastAPI(title="Trenches OpenEnv Backend", version="0.1.0", lifespan=lifespan)

    # Resolve CORS once so the middleware and the capabilities report below
    # always describe the same configuration.
    cors_settings = _resolve_cors_settings()
    app.add_middleware(CORSMiddleware, **cors_settings)

    runtime = OpenEnvAdapter(session_manager=manager)

    # The factory builds a fresh environment per call; its harvester is not
    # auto-started, unlike the one backing the default session manager.
    openenv_app = create_openenv_fastapi_app(
        lambda: TrenchesOpenEnvEnvironment(
            env=FogOfWarDiplomacyEnv(
                source_harvester=SourceHarvester(auto_start=False),
            ).enable_source_warm_start()
        )
    )
    if openenv_app is not None:
        app.mount("/openenv", openenv_app)

    # NOTE(review): no route decorators (e.g. ``@app.get(...)`` /
    # ``@app.post(...)``) are visible on the handlers below in the reviewed
    # source, so as written none of them is registered on ``app``. They look to
    # have been lost in extraction; restore them with the real paths and HTTP
    # methods before relying on this module.

    async def healthz() -> dict[str, str]:
        """Liveness probe; always reports ``ok``."""
        return {"status": "ok"}

    async def capabilities() -> dict[str, Any]:
        """Describe the model bindings, available API surfaces and CORS settings."""
        return {
            "model_bindings": {
                agent_id: binding.model_dump(mode="json")
                for agent_id, binding in build_entity_model_bindings().items()
            },
            "session_api": True,
            "legacy_openenv_tuple_api": True,
            "native_openenv_api": OPENENV_CORE_AVAILABLE,
            "native_openenv_base_path": "/openenv" if OPENENV_CORE_AVAILABLE else None,
            "cors": {
                "allow_origins": cors_settings["allow_origins"],
                "allow_origin_regex": cors_settings["allow_origin_regex"],
                "allow_credentials": cors_settings["allow_credentials"],
            },
        }

    async def create_session(request: CreateSessionRequest) -> SessionState:
        """Create a session from the request fields.

        Raises:
            HTTPException: 400 when the manager rejects the arguments with
                ``ValueError`` (same mapping as ``reset_session``).
        """
        try:
            return manager.create_session(
                seed=request.seed,
                training_agent=request.training_agent,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    async def reset_session(session_id: str, request: ResetSessionRequest) -> SessionState:
        """Reset an existing session; 404 for an unknown id, 400 for bad arguments."""
        try:
            return manager.reset_session(
                session_id=session_id,
                seed=request.seed,
                training_agent=request.training_agent,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    async def list_scenarios() -> list[ScenarioSummary]:
        """Return the scenario summaries reported by the session manager."""
        return manager.list_scenarios()

    async def run_benchmark(request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        """Run a benchmark; ``ValueError`` from the manager becomes a 400."""
        try:
            return manager.run_benchmark(request)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    async def get_session(session_id: str) -> SessionState:
        """Return the state of a session; 404 for an unknown id."""
        try:
            return manager.get_session(session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    async def refresh_session_sources(session_id: str) -> SessionState:
        """Force a source refresh for a session; 404 for an unknown id."""
        try:
            return manager.refresh_session_sources(session_id=session_id, force=True)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    async def source_monitor(session_id: str) -> SourceMonitorReport:
        """Return the source monitor report for a session; 404 for an unknown id."""
        try:
            return manager.source_monitor(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    async def reaction_log(session_id: str) -> list[ReactionLogEntry]:
        """Return the reaction log for a session; 404 for an unknown id."""
        try:
            return manager.reaction_log(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    async def provider_diagnostics(session_id: str) -> ProviderDiagnosticsResponse:
        """Return provider diagnostics for a session; 404 for an unknown id."""
        try:
            return manager.provider_diagnostics(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    async def ingest_news(session_id: str, request: IngestNewsRequest) -> IngestNewsResponse:
        """Pass a news ingest request to the manager; 404 unknown id, 400 bad input."""
        try:
            return manager.ingest_news(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    async def set_live_mode(session_id: str, request: LiveControlRequest) -> SessionState:
        """Apply a live-control request to a session; 404 unknown id, 400 bad input."""
        try:
            return manager.set_live_mode(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    async def step_session(session_id: str, request: StepSessionRequest) -> StepSessionResponse:
        """Advance a session by one step; 404 unknown id, 400 bad input."""
        try:
            return manager.step_session(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    async def reset_env(request: ResetEnvRequest) -> ResetEnvResponse:
        """Reset the tuple-style OpenEnv runtime and return observations and info.

        Raises:
            HTTPException: 400 when the runtime rejects the arguments with
                ``ValueError`` (same mapping as ``step_env``).
        """
        try:
            observations, info = runtime.reset(
                seed=request.seed,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        return ResetEnvResponse(observations=observations, info=info)

    async def step_env(request: StepEnvRequest) -> StepEnvResponse:
        """Step the tuple-style OpenEnv runtime; ``ValueError`` becomes a 400."""
        try:
            observations, rewards, terminated, truncated, info = runtime.step(
                actions=request.actions,
                predictions=request.predictions,
                external_signals=request.external_signals,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        return StepEnvResponse(
            observations=observations,
            rewards=rewards,
            terminated=terminated,
            truncated=truncated,
            info=info,
        )

    async def state_env() -> SessionState:
        """Return the runtime's current session; 404 when there is none."""
        session = runtime.state()
        if session is None:
            raise HTTPException(status_code=404, detail="No active OpenEnv runtime session.")
        return session

    return app
# Module-level application built with the default session manager.
# run() passes uvicorn the import string "trenches_env.server:app",
# presumably resolving to this object (confirm this file's module path).
app = create_app()
def run(host: str = "0.0.0.0", port: int = 8000, *, reload: bool = False) -> None:
    """Serve the application with uvicorn.

    Args:
        host: Interface to bind. The default ``0.0.0.0`` listens on all
            interfaces.
        port: TCP port to listen on.
        reload: Passed to uvicorn's ``reload`` option.

    Calling ``run()`` with no arguments behaves exactly as before.
    """
    uvicorn.run("trenches_env.server:app", host=host, port=port, reload=reload)
# Start the server when this file is executed directly as a script.
if __name__ == "__main__":
    run()