tokev's picture
Add files using upload-large-folder tool
5893134 verified
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
import os
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from openenv_app.openenv_wrapper import OpenEnvTrafficWrapper
from openenv_app.replay_runner import get_cached, run_and_cache
from openenv_app.schema import (
HealthResponse,
ReplayResponse,
ResetRequest,
ResetResponse,
StateResponse,
StepRequest,
StepResponse,
)
from server.path_validators import validate_path_segment
logger = logging.getLogger(__name__)
_REPO_ROOT = Path(__file__).resolve().parents[1]
DATA_DIR = Path(os.environ.get("DATA_DIR", "") or (_REPO_ROOT / "data" / "generated"))
SPLITS_DIR = Path(os.environ.get("SPLITS_DIR", "") or (_REPO_ROOT / "data" / "splits"))
CHECKPOINT_PATH = Path(
os.environ.get("CHECKPOINT_PATH", "")
or (_REPO_ROOT / "artifacts" / "dqn_shared" / "best_validation.pt")
)
# ---------------------------------------------------------------------------
# Startup / lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load the DQN checkpoint once at startup so replay requests are fast."""
if CHECKPOINT_PATH.exists():
from server.policy_runner import load_dqn_checkpoint
load_dqn_checkpoint(CHECKPOINT_PATH)
else:
logger.warning("Checkpoint not found at %s — 'learned' policy will fail", CHECKPOINT_PATH)
yield
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
title="DistrictFlow OpenEnv App",
description="OpenEnv-style traffic environment for district-level LLM coordination.",
version="0.1.0",
lifespan=lifespan,
)
# Lazy-initialized: only constructed when /reset or /step is first called.
_wrapper: OpenEnvTrafficWrapper | None = None
def _get_wrapper() -> OpenEnvTrafficWrapper:
global _wrapper
if _wrapper is None:
_wrapper = OpenEnvTrafficWrapper(
generated_root=DATA_DIR,
splits_root=SPLITS_DIR,
)
return _wrapper
# ---------------------------------------------------------------------------
# Health
# ---------------------------------------------------------------------------
@app.get("/", response_model=HealthResponse)
def root():
return HealthResponse(status="ok", message="DistrictFlow OpenEnv app is running.")
@app.get("/health", response_model=HealthResponse)
def health():
return HealthResponse(status="ok", message="healthy")
# ---------------------------------------------------------------------------
# Step / Reset / State
# ---------------------------------------------------------------------------
@app.post("/reset", response_model=ResetResponse)
def reset(request: ResetRequest):
payload = _get_wrapper().reset(
seed=request.seed,
city_id=request.city_id,
scenario_name=request.scenario_name,
)
return ResetResponse(observation=payload["observation"], info=payload.get("info", {}))
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest):
payload = _get_wrapper().step(action=request.action)
return StepResponse(
observation=payload["observation"],
reward=payload["reward"],
done=payload["done"],
truncated=payload.get("truncated", False),
info=payload.get("info", {}),
)
@app.get("/state", response_model=StateResponse)
def state():
payload = _get_wrapper().state()
return StateResponse(state=payload["state"])
# ---------------------------------------------------------------------------
# Replay (on-demand simulation + in-memory cache)
# ---------------------------------------------------------------------------
_VALID_POLICIES = {"no_intervention", "fixed", "random", "learned"}
@app.get("/replay/{city_id}/{scenario_name}/{policy_name}", response_model=ReplayResponse)
def get_replay(city_id: str, scenario_name: str, policy_name: str) -> ReplayResponse:
"""Run a full simulation and return the CityFlow replay + metrics.
Results are cached in memory so repeated calls are instant.
"""
validate_path_segment(city_id, "city_id")
validate_path_segment(scenario_name, "scenario_name")
if policy_name not in _VALID_POLICIES:
raise HTTPException(
status_code=400,
detail=f"Unknown policy '{policy_name}'. Valid: {sorted(_VALID_POLICIES)}",
)
cached = get_cached(city_id, scenario_name, policy_name)
if cached is None:
try:
cached = run_and_cache(
city_id=city_id,
scenario_name=scenario_name,
policy_name=policy_name,
generated_root=DATA_DIR,
)
except FileNotFoundError as exc:
logger.error("Replay file missing after simulation: %s", exc)
raise HTTPException(
status_code=500,
detail="Simulation completed but no replay file was produced.",
) from exc
except Exception as exc:
logger.error("Simulation failed for %s/%s/%s: %s", city_id, scenario_name, policy_name, exc)
raise HTTPException(status_code=500, detail="Simulation failed.") from exc
replay_text, roadnet_log, metrics = cached
return ReplayResponse(
city_id=city_id,
scenario_name=scenario_name,
policy_name=policy_name,
replay_text=replay_text,
roadnet_log=roadnet_log,
metrics=metrics,
)
# ---------------------------------------------------------------------------
# Error handler
# ---------------------------------------------------------------------------
@app.exception_handler(Exception)
def unhandled_exception_handler(request, exc):
logger.error("Unhandled exception: %s: %s", type(exc).__name__, exc)
return JSONResponse(status_code=500, content={"error": "Internal server error"})