# Codex
# sync main snapshot for HF Space
# 1794757
from __future__ import annotations
from contextlib import asynccontextmanager
import os
from typing import Any
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.model_runtime import build_entity_model_bindings
from trenches_env.models import (
BenchmarkRunRequest,
BenchmarkRunResponse,
CreateSessionRequest,
IngestNewsRequest,
IngestNewsResponse,
LiveControlRequest,
ProviderDiagnosticsResponse,
ReactionLogEntry,
ResetEnvRequest,
ResetEnvResponse,
ResetSessionRequest,
ScenarioSummary,
SessionState,
SourceMonitorReport,
StepEnvRequest,
StepEnvResponse,
StepSessionRequest,
StepSessionResponse,
)
from trenches_env.openenv_adapter import (
OPENENV_CORE_AVAILABLE,
OpenEnvAdapter,
TrenchesOpenEnvEnvironment,
create_openenv_fastapi_app,
)
from trenches_env.session_manager import SessionManager
from trenches_env.source_ingestion import SourceHarvester
# Fallback origin pattern used by _resolve_cors_settings when neither
# TRENCHES_CORS_ALLOW_ORIGINS nor TRENCHES_CORS_ALLOW_ORIGIN_REGEX provides a
# value: http or https on localhost / 127.0.0.1, with an optional port.
DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX = r"https?://(localhost|127\.0\.0\.1)(:\d+)?$"
def _parse_csv_env(raw_value: str | None) -> list[str]:
if not raw_value:
return []
return [item.strip() for item in raw_value.split(",") if item.strip()]
def _resolve_cors_settings() -> dict[str, Any]:
    """Build the keyword arguments for ``CORSMiddleware`` from the environment.

    Environment variables read:
        TRENCHES_CORS_ALLOW_ORIGINS: comma-separated list of allowed origins.
            If ``*`` appears anywhere in the list, wildcard mode is selected.
        TRENCHES_CORS_ALLOW_ORIGIN_REGEX: regular expression for allowed
            origins. A blank or whitespace-only value is treated as unset.
        TRENCHES_CORS_ALLOW_CREDENTIALS: credentials are enabled unless the
            value (case-insensitive, trimmed) is ``0``, ``false``, ``no`` or
            ``off``. Ignored in wildcard mode.

    Returns:
        A dict with the keys ``allow_origins``, ``allow_origin_regex``,
        ``allow_credentials``, ``allow_methods`` and ``allow_headers``.
    """
    allow_origins = _parse_csv_env(os.getenv("TRENCHES_CORS_ALLOW_ORIGINS"))
    # Normalise blank input to None. A whitespace-only value would otherwise
    # be truthy, suppress the local-dev fallback below, and be handed to the
    # middleware as a pattern that matches no real origin.
    allow_origin_regex = (os.getenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX") or "").strip() or None

    if "*" in allow_origins:
        return {
            "allow_origins": ["*"],
            "allow_origin_regex": None,
            # Browsers reject wildcard origins when credentials are enabled.
            "allow_credentials": False,
            "allow_methods": ["*"],
            "allow_headers": ["*"],
        }

    # Nothing configured at all: fall back to the local-development pattern.
    if not allow_origins and allow_origin_regex is None:
        allow_origin_regex = DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX

    # Credentials default to enabled; only an explicit "off" spelling disables them.
    allow_credentials = os.getenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "true").strip().lower() not in {
        "0",
        "false",
        "no",
        "off",
    }
    return {
        "allow_origins": allow_origins,
        "allow_origin_regex": allow_origin_regex,
        "allow_credentials": allow_credentials,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
def create_app(session_manager: SessionManager | None = None) -> FastAPI:
    """Build the FastAPI application with the session, benchmark and OpenEnv routes.

    Args:
        session_manager: Manager backing every route. When omitted, a
            ``SessionManager`` is created around a ``FogOfWarDiplomacyEnv``
            whose ``SourceHarvester`` is built with ``auto_start=True`` and
            on which ``enable_source_warm_start()`` has been called.

    Returns:
        The configured application: CORS middleware installed from
        ``_resolve_cors_settings()``, the native OpenEnv sub-application
        mounted at ``/openenv`` when ``create_openenv_fastapi_app`` returns
        one, and all routes registered.

    Route error mapping:
        ``KeyError`` from the manager becomes HTTP 404 ("Unknown session")
        and ``ValueError`` becomes HTTP 400 carrying the exception text.
    """
    manager = session_manager or SessionManager(
        env=FogOfWarDiplomacyEnv(
            source_harvester=SourceHarvester(auto_start=True),
        ).enable_source_warm_start()
    )

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # Start the manager's background runner for the lifetime of the app.
        # shutdown() sits in ``finally`` so it runs even if startup or the
        # serving phase raises.
        try:
            manager.start_background_runner()
            yield
        finally:
            manager.shutdown()

    app = FastAPI(title="Trenches OpenEnv Backend", version="0.1.0", lifespan=lifespan)
    app.add_middleware(CORSMiddleware, **_resolve_cors_settings())

    # Adapter behind the tuple-style /reset, /step and /state routes below.
    runtime = OpenEnvAdapter(session_manager=manager)

    # The native OpenEnv app receives a factory that builds its own
    # environment; note its harvester is created with auto_start=False.
    openenv_app = create_openenv_fastapi_app(
        lambda: TrenchesOpenEnvEnvironment(
            env=FogOfWarDiplomacyEnv(
                source_harvester=SourceHarvester(auto_start=False),
            ).enable_source_warm_start()
        )
    )
    if openenv_app is not None:
        app.mount("/openenv", openenv_app)

    # Liveness probe.
    @app.get("/healthz")
    async def healthz() -> dict[str, str]:
        return {"status": "ok"}

    # Report model bindings, available API flavours and CORS settings.
    @app.get("/capabilities")
    async def capabilities() -> dict[str, Any]:
        # Settings are re-read from the environment on every request.
        cors_settings = _resolve_cors_settings()
        return {
            "model_bindings": {
                agent_id: binding.model_dump(mode="json")
                for agent_id, binding in build_entity_model_bindings().items()
            },
            "session_api": True,
            "legacy_openenv_tuple_api": True,
            "native_openenv_api": OPENENV_CORE_AVAILABLE,
            "native_openenv_base_path": "/openenv" if OPENENV_CORE_AVAILABLE else None,
            "cors": {
                "allow_origins": cors_settings["allow_origins"],
                "allow_origin_regex": cors_settings["allow_origin_regex"],
                "allow_credentials": cors_settings["allow_credentials"],
            },
        }

    # Create a new session.
    @app.post("/sessions", response_model=SessionState)
    async def create_session(request: CreateSessionRequest) -> SessionState:
        # Mirror /sessions/{session_id}/reset, which forwards the same
        # parameters: a ValueError from the manager becomes a 400 instead of
        # surfacing as an unhandled 500.
        try:
            return manager.create_session(
                seed=request.seed,
                training_agent=request.training_agent,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Reset an existing session with new parameters.
    @app.post("/sessions/{session_id}/reset", response_model=SessionState)
    async def reset_session(session_id: str, request: ResetSessionRequest) -> SessionState:
        try:
            return manager.reset_session(
                session_id=session_id,
                seed=request.seed,
                training_agent=request.training_agent,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # List the scenarios known to the manager.
    @app.get("/scenarios", response_model=list[ScenarioSummary])
    async def list_scenarios() -> list[ScenarioSummary]:
        return manager.list_scenarios()

    # Run a benchmark described by the request body.
    @app.post("/benchmarks/run", response_model=BenchmarkRunResponse)
    async def run_benchmark(request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        try:
            return manager.run_benchmark(request)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Fetch the current state of a session.
    @app.get("/sessions/{session_id}", response_model=SessionState)
    async def get_session(session_id: str) -> SessionState:
        try:
            return manager.get_session(session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    # Force a refresh of the session's sources.
    @app.post("/sessions/{session_id}/sources/refresh", response_model=SessionState)
    async def refresh_session_sources(session_id: str) -> SessionState:
        try:
            return manager.refresh_session_sources(session_id=session_id, force=True)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    # Source monitor report for a session.
    @app.get("/sessions/{session_id}/sources/monitor", response_model=SourceMonitorReport)
    async def source_monitor(session_id: str) -> SourceMonitorReport:
        try:
            return manager.source_monitor(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    # Reaction log entries for a session.
    @app.get("/sessions/{session_id}/reactions", response_model=list[ReactionLogEntry])
    async def reaction_log(session_id: str) -> list[ReactionLogEntry]:
        try:
            return manager.reaction_log(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    # Provider diagnostics for a session.
    @app.get("/sessions/{session_id}/providers/diagnostics", response_model=ProviderDiagnosticsResponse)
    async def provider_diagnostics(session_id: str) -> ProviderDiagnosticsResponse:
        try:
            return manager.provider_diagnostics(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    # Ingest news into a session.
    @app.post("/sessions/{session_id}/news", response_model=IngestNewsResponse)
    async def ingest_news(session_id: str, request: IngestNewsRequest) -> IngestNewsResponse:
        try:
            return manager.ingest_news(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Change a session's live-mode settings.
    @app.post("/sessions/{session_id}/live", response_model=SessionState)
    async def set_live_mode(session_id: str, request: LiveControlRequest) -> SessionState:
        try:
            return manager.set_live_mode(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Advance a session by one step.
    @app.post("/sessions/{session_id}/step", response_model=StepSessionResponse)
    async def step_session(session_id: str, request: StepSessionRequest) -> StepSessionResponse:
        try:
            return manager.step_session(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Tuple-style reset routed through the OpenEnv adapter.
    @app.post("/reset", response_model=ResetEnvResponse)
    async def reset_env(request: ResetEnvRequest) -> ResetEnvResponse:
        # Same ValueError -> 400 mapping that /step applies to the adapter.
        try:
            observations, info = runtime.reset(
                seed=request.seed,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        return ResetEnvResponse(observations=observations, info=info)

    # Tuple-style step routed through the OpenEnv adapter.
    @app.post("/step", response_model=StepEnvResponse)
    async def step_env(request: StepEnvRequest) -> StepEnvResponse:
        try:
            observations, rewards, terminated, truncated, info = runtime.step(
                actions=request.actions,
                predictions=request.predictions,
                external_signals=request.external_signals,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        return StepEnvResponse(
            observations=observations,
            rewards=rewards,
            terminated=terminated,
            truncated=truncated,
            info=info,
        )

    # Current adapter session; 404 when the adapter reports none.
    @app.get("/state", response_model=SessionState)
    async def state_env() -> SessionState:
        session = runtime.state()
        if session is None:
            raise HTTPException(status_code=404, detail="No active OpenEnv runtime session.")
        return session

    return app
# Module-level application instance, built at import time with the default
# session manager (no session_manager argument is passed). run() hands
# uvicorn the import string "trenches_env.server:app" — presumably this name.
app = create_app()
def run(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the application with uvicorn, auto-reload disabled.

    Args:
        host: Interface to bind. The default keeps the previous behaviour of
            listening on all interfaces.
        port: TCP port to listen on. The default keeps the previous value.
    """
    uvicorn.run("trenches_env.server:app", host=host, port=port, reload=False)


if __name__ == "__main__":
    run()