Spaces:
Running
Running
| from __future__ import annotations | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from openenv_app.openenv_wrapper import OpenEnvTrafficWrapper | |
| from openenv_app.replay_runner import get_cached, run_and_cache | |
| from openenv_app.schema import ( | |
| HealthResponse, | |
| ReplayResponse, | |
| ResetRequest, | |
| ResetResponse, | |
| StateResponse, | |
| StepRequest, | |
| StepResponse, | |
| ) | |
| from server.path_validators import validate_path_segment | |
logger = logging.getLogger(__name__)

# Two levels up from this file: parents[0] is this file's directory,
# parents[1] is the directory that contains it.
_REPO_ROOT = Path(__file__).resolve().parents[1]


def _path_from_env(var_name: str, fallback: Path) -> Path:
    """Return the path stored in env var *var_name*, or *fallback*.

    An unset variable and a variable set to the empty string are treated the
    same way: both select *fallback*.
    """
    configured = os.environ.get(var_name, "")
    if configured:
        return Path(configured)
    return Path(fallback)


# Each location can be overridden through an environment variable of the
# same name; otherwise it resolves relative to _REPO_ROOT.
DATA_DIR = _path_from_env("DATA_DIR", _REPO_ROOT / "data" / "generated")
SPLITS_DIR = _path_from_env("SPLITS_DIR", _REPO_ROOT / "data" / "splits")
CHECKPOINT_PATH = _path_from_env(
    "CHECKPOINT_PATH",
    _REPO_ROOT / "artifacts" / "dqn_shared" / "best_validation.pt",
)
| # --------------------------------------------------------------------------- | |
| # Startup / lifespan | |
| # --------------------------------------------------------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the DQN checkpoint once at startup so replay requests are fast.

    Used as the ``lifespan`` argument of the FastAPI application. Everything
    before the ``yield`` runs before the app starts serving; nothing is torn
    down afterwards.

    The function is wrapped with ``asynccontextmanager`` so that calling it
    returns a real async context manager (usable with ``async with``) instead
    of a bare async generator.

    Args:
        app: The application instance being started. Not used by this hook.

    Yields:
        None, once the startup work is finished.
    """
    if CHECKPOINT_PATH.exists():
        # Imported only on this branch, so the policy runner module is not
        # loaded when there is no checkpoint to hand to it.
        from server.policy_runner import load_dqn_checkpoint

        load_dqn_checkpoint(CHECKPOINT_PATH)
    else:
        # A missing checkpoint is not fatal at startup; only warn.
        logger.warning("Checkpoint not found at %s — 'learned' policy will fail", CHECKPOINT_PATH)
    yield
| # --------------------------------------------------------------------------- | |
| # App | |
| # --------------------------------------------------------------------------- | |
# The single FastAPI application object for this module. The `lifespan`
# callable is handed over so its startup work runs before serving begins.
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="DistrictFlow OpenEnv App",
    description="OpenEnv-style traffic environment for district-level LLM coordination.",
)
# Shared environment wrapper. Starts out as None and is built by the first
# call to _get_wrapper(), not at import time.
_wrapper: OpenEnvTrafficWrapper | None = None


def _get_wrapper() -> OpenEnvTrafficWrapper:
    """Return the module-wide ``OpenEnvTrafficWrapper``, creating it if needed.

    The first call constructs the wrapper from ``DATA_DIR`` and ``SPLITS_DIR``
    and stores it in the module global ``_wrapper``; later calls return that
    same object.
    """
    global _wrapper
    if _wrapper is not None:
        return _wrapper
    _wrapper = OpenEnvTrafficWrapper(
        generated_root=DATA_DIR,
        splits_root=SPLITS_DIR,
    )
    return _wrapper
| # --------------------------------------------------------------------------- | |
| # Health | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# presumably it is registered on `app` elsewhere — confirm.
def root():
    """Return a fixed "ok" health payload saying the app is running."""
    response = HealthResponse(
        message="DistrictFlow OpenEnv app is running.",
        status="ok",
    )
    return response
# NOTE(review): no route decorator is visible on this handler in this view;
# presumably it is registered on `app` elsewhere — confirm.
def health():
    """Return a fixed health payload with status "ok" and message "healthy"."""
    response = HealthResponse(message="healthy", status="ok")
    return response
| # --------------------------------------------------------------------------- | |
| # Step / Reset / State | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# presumably it is registered on `app` elsewhere — confirm.
def reset(request: ResetRequest):
    """Reset the shared environment wrapper and return its first observation.

    Forwards ``seed``, ``city_id`` and ``scenario_name`` from *request* to the
    wrapper's ``reset``. The result must contain an ``"observation"`` key;
    ``"info"`` is optional and defaults to an empty dict.
    """
    wrapper = _get_wrapper()
    result = wrapper.reset(
        seed=request.seed,
        city_id=request.city_id,
        scenario_name=request.scenario_name,
    )
    observation = result["observation"]
    info = result.get("info", {})
    return ResetResponse(observation=observation, info=info)
# NOTE(review): no route decorator is visible on this handler in this view;
# presumably it is registered on `app` elsewhere — confirm.
def step(request: StepRequest):
    """Apply ``request.action`` to the shared wrapper and return the transition.

    ``"observation"``, ``"reward"`` and ``"done"`` are required keys of the
    wrapper's result. ``"truncated"`` defaults to ``False`` and ``"info"`` to
    an empty dict when absent.
    """
    wrapper = _get_wrapper()
    result = wrapper.step(action=request.action)
    observation = result["observation"]
    reward = result["reward"]
    done = result["done"]
    truncated = result.get("truncated", False)
    info = result.get("info", {})
    return StepResponse(
        observation=observation,
        reward=reward,
        done=done,
        truncated=truncated,
        info=info,
    )
# NOTE(review): no route decorator is visible on this handler in this view;
# presumably it is registered on `app` elsewhere — confirm.
def state():
    """Return the ``"state"`` entry of the shared wrapper's ``state()`` result."""
    snapshot = _get_wrapper().state()
    current = snapshot["state"]
    return StateResponse(state=current)
| # --------------------------------------------------------------------------- | |
| # Replay (on-demand simulation + in-memory cache) | |
| # --------------------------------------------------------------------------- | |
# Policy names accepted by get_replay(); anything else is rejected with a 400.
_VALID_POLICIES = {"no_intervention", "fixed", "random", "learned"}


# NOTE(review): no route decorator is visible on this handler in this view;
# presumably it is registered on `app` elsewhere — confirm.
def get_replay(city_id: str, scenario_name: str, policy_name: str) -> ReplayResponse:
    """Run a full simulation and return the CityFlow replay + metrics.

    Results are cached in memory so repeated calls are instant: the cache is
    consulted first via ``get_cached`` and the simulation is only run (via
    ``run_and_cache``) on a miss.

    Args:
        city_id: City identifier; checked with ``validate_path_segment``.
        scenario_name: Scenario identifier; checked with
            ``validate_path_segment``.
        policy_name: One of ``_VALID_POLICIES``.

    Returns:
        A ``ReplayResponse`` echoing the three identifiers together with the
        replay text, roadnet log and metrics.

    Raises:
        HTTPException: 400 when *policy_name* is not a known policy; 500 when
            the simulation raises. ``validate_path_segment`` may raise as
            well — its behavior is defined outside this file.
    """
    validate_path_segment(city_id, "city_id")
    validate_path_segment(scenario_name, "scenario_name")
    if policy_name not in _VALID_POLICIES:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown policy '{policy_name}'. Valid: {sorted(_VALID_POLICIES)}",
        )

    cached = get_cached(city_id, scenario_name, policy_name)
    if cached is None:
        try:
            cached = run_and_cache(
                city_id=city_id,
                scenario_name=scenario_name,
                policy_name=policy_name,
                generated_root=DATA_DIR,
            )
        except FileNotFoundError as exc:
            # logger.exception records the traceback; the client only ever
            # sees the generic detail string below.
            logger.exception("Replay file missing after simulation: %s", exc)
            raise HTTPException(
                status_code=500,
                detail="Simulation completed but no replay file was produced.",
            ) from exc
        except Exception as exc:
            # Boundary catch-all: keep the traceback in the log, since the
            # original error is replaced by an HTTPException here.
            logger.exception(
                "Simulation failed for %s/%s/%s: %s", city_id, scenario_name, policy_name, exc
            )
            raise HTTPException(status_code=500, detail="Simulation failed.") from exc

    replay_text, roadnet_log, metrics = cached
    return ReplayResponse(
        city_id=city_id,
        scenario_name=scenario_name,
        policy_name=policy_name,
        replay_text=replay_text,
        roadnet_log=roadnet_log,
        metrics=metrics,
    )
| # --------------------------------------------------------------------------- | |
| # Error handler | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): no registration of this handler (e.g. an exception-handler
# decorator on `app`) is visible in this view — confirm it is wired up.
def unhandled_exception_handler(request, exc):
    """Log an exception's type and message and answer with a generic 500.

    The JSON body is a fixed error string; no detail from *exc* is sent to
    the client. *request* is accepted but not used.
    """
    exc_kind = type(exc).__name__
    logger.error("Unhandled exception: %s: %s", exc_kind, exc)
    body = {"error": "Internal server error"}
    return JSONResponse(status_code=500, content=body)