| from __future__ import annotations |
|
|
| import logging |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
| import os |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import JSONResponse |
|
|
| from openenv_app.openenv_wrapper import OpenEnvTrafficWrapper |
| from openenv_app.replay_runner import get_cached, run_and_cache |
| from openenv_app.schema import ( |
| HealthResponse, |
| ReplayResponse, |
| ResetRequest, |
| ResetResponse, |
| StateResponse, |
| StepRequest, |
| StepResponse, |
| ) |
| from server.path_validators import validate_path_segment |
|
|
| logger = logging.getLogger(__name__) |
|
|
| _REPO_ROOT = Path(__file__).resolve().parents[1] |
| DATA_DIR = Path(os.environ.get("DATA_DIR", "") or (_REPO_ROOT / "data" / "generated")) |
| SPLITS_DIR = Path(os.environ.get("SPLITS_DIR", "") or (_REPO_ROOT / "data" / "splits")) |
| CHECKPOINT_PATH = Path( |
| os.environ.get("CHECKPOINT_PATH", "") |
| or (_REPO_ROOT / "artifacts" / "dqn_shared" / "best_validation.pt") |
| ) |
|
|
| |
| |
| |
|
|
|
|
| @asynccontextmanager |
| async def lifespan(app: FastAPI): |
| """Load the DQN checkpoint once at startup so replay requests are fast.""" |
| if CHECKPOINT_PATH.exists(): |
| from server.policy_runner import load_dqn_checkpoint |
| load_dqn_checkpoint(CHECKPOINT_PATH) |
| else: |
| logger.warning("Checkpoint not found at %s — 'learned' policy will fail", CHECKPOINT_PATH) |
| yield |
|
|
|
|
| |
| |
| |
|
|
| app = FastAPI( |
| title="DistrictFlow OpenEnv App", |
| description="OpenEnv-style traffic environment for district-level LLM coordination.", |
| version="0.1.0", |
| lifespan=lifespan, |
| ) |
|
|
| |
| _wrapper: OpenEnvTrafficWrapper | None = None |
|
|
|
|
| def _get_wrapper() -> OpenEnvTrafficWrapper: |
| global _wrapper |
| if _wrapper is None: |
| _wrapper = OpenEnvTrafficWrapper( |
| generated_root=DATA_DIR, |
| splits_root=SPLITS_DIR, |
| ) |
| return _wrapper |
|
|
|
|
| |
| |
| |
|
|
|
|
| @app.get("/", response_model=HealthResponse) |
| def root(): |
| return HealthResponse(status="ok", message="DistrictFlow OpenEnv app is running.") |
|
|
|
|
| @app.get("/health", response_model=HealthResponse) |
| def health(): |
| return HealthResponse(status="ok", message="healthy") |
|
|
|
|
| |
| |
| |
|
|
|
|
| @app.post("/reset", response_model=ResetResponse) |
| def reset(request: ResetRequest): |
| payload = _get_wrapper().reset( |
| seed=request.seed, |
| city_id=request.city_id, |
| scenario_name=request.scenario_name, |
| ) |
| return ResetResponse(observation=payload["observation"], info=payload.get("info", {})) |
|
|
|
|
| @app.post("/step", response_model=StepResponse) |
| def step(request: StepRequest): |
| payload = _get_wrapper().step(action=request.action) |
| return StepResponse( |
| observation=payload["observation"], |
| reward=payload["reward"], |
| done=payload["done"], |
| truncated=payload.get("truncated", False), |
| info=payload.get("info", {}), |
| ) |
|
|
|
|
| @app.get("/state", response_model=StateResponse) |
| def state(): |
| payload = _get_wrapper().state() |
| return StateResponse(state=payload["state"]) |
|
|
|
|
| |
| |
| |
|
|
| _VALID_POLICIES = {"no_intervention", "fixed", "random", "learned"} |
|
|
|
|
| @app.get("/replay/{city_id}/{scenario_name}/{policy_name}", response_model=ReplayResponse) |
| def get_replay(city_id: str, scenario_name: str, policy_name: str) -> ReplayResponse: |
| """Run a full simulation and return the CityFlow replay + metrics. |
| |
| Results are cached in memory so repeated calls are instant. |
| """ |
| validate_path_segment(city_id, "city_id") |
| validate_path_segment(scenario_name, "scenario_name") |
|
|
| if policy_name not in _VALID_POLICIES: |
| raise HTTPException( |
| status_code=400, |
| detail=f"Unknown policy '{policy_name}'. Valid: {sorted(_VALID_POLICIES)}", |
| ) |
|
|
| cached = get_cached(city_id, scenario_name, policy_name) |
| if cached is None: |
| try: |
| cached = run_and_cache( |
| city_id=city_id, |
| scenario_name=scenario_name, |
| policy_name=policy_name, |
| generated_root=DATA_DIR, |
| ) |
| except FileNotFoundError as exc: |
| logger.error("Replay file missing after simulation: %s", exc) |
| raise HTTPException( |
| status_code=500, |
| detail="Simulation completed but no replay file was produced.", |
| ) from exc |
| except Exception as exc: |
| logger.error("Simulation failed for %s/%s/%s: %s", city_id, scenario_name, policy_name, exc) |
| raise HTTPException(status_code=500, detail="Simulation failed.") from exc |
|
|
| replay_text, roadnet_log, metrics = cached |
| return ReplayResponse( |
| city_id=city_id, |
| scenario_name=scenario_name, |
| policy_name=policy_name, |
| replay_text=replay_text, |
| roadnet_log=roadnet_log, |
| metrics=metrics, |
| ) |
|
|
|
|
| |
| |
| |
|
|
|
|
| @app.exception_handler(Exception) |
| def unhandled_exception_handler(request, exc): |
| logger.error("Unhandled exception: %s: %s", type(exc).__name__, exc) |
| return JSONResponse(status_code=500, content={"error": "Internal server error"}) |
|
|