""" OpenEnv-compatible FastAPI Server Endpoints: GET /reset POST /step GET /lidar POST /negotiate GET /state Deploy: uvicorn server.server:app --host 0.0.0.0 --port 7860 """ import sys, os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from env.negotiation_env import NegotiationDrivingEnv app = FastAPI(title="Autonomous Driving OpenEnv", version="0.2.1") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) env = NegotiationDrivingEnv() obs, _ = env.reset() # ── Request / Response models ───────────────────────────────────────── class StepRequest(BaseModel): action: int # 0=accelerate 1=brake 2=lane_left 3=lane_right reasoning: str = "" # LLM chain-of-thought (stored for logging) class NegotiateRequest(BaseModel): target: str # "blocker" or "traffic" message: str # ── Endpoints ───────────────────────────────────────────────────────── @app.get("/health") def health(): return {"status": "ok", "version": "0.2.1"} @app.get("/reset") def reset(): global obs obs, info = env.reset() return { "observation": obs.tolist(), "render": env.render(), "lidar": env.lidar_scan(), "info": info, } @app.post("/step") def step(req: StepRequest): global obs if req.action not in [0, 1, 2, 3]: raise HTTPException(400, "action must be 0-3") obs, reward, done, truncated, info = env.step(req.action) return { "observation": obs.tolist(), "reward": round(reward, 4), "done": done, "truncated": truncated, "info": info, "render": env.render(), "lidar": env.lidar_scan(), "collision": env.predict_collision(), "negotiation_log": env.negotiation_log[-5:], } @app.get("/lidar") def lidar(): return { **env.lidar_scan(), **env.predict_collision(), } @app.post("/negotiate") def negotiate(req: NegotiateRequest): response = env.negotiate(req.target, req.message) return { "response": response, "negotiation_log": env.negotiation_log, "blocker_yielding": env._blocker_yielding, } @app.get("/state") def state(): return { "ego": env.ego.tolist(), "blocker": env.blocker.tolist(), "traffic": env.traffic.tolist(), "step": env.step_count, "render": env.render(), "memory": env.memory[-5:], "neg_log": env.negotiation_log, } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)