# AD / server / server.py
# Uploaded by helshahaby ("Upload 13 files"), commit 3d58f38 (verified)
"""
OpenEnv-compatible FastAPI Server
Endpoints: GET /reset POST /step GET /lidar POST /negotiate GET /state
Deploy: uvicorn server.server:app --host 0.0.0.0 --port 7860
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from env.negotiation_env import NegotiationDrivingEnv
app = FastAPI(title="Autonomous Driving OpenEnv", version="0.2.1")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
env = NegotiationDrivingEnv()
obs, _ = env.reset()
# ── Request / Response models ─────────────────────────────────────────
class StepRequest(BaseModel):
    """JSON request body for ``POST /step``."""

    action: int # 0=accelerate 1=brake 2=lane_left 3=lane_right (range checked by the /step handler)
    reasoning: str = "" # LLM chain-of-thought, intended for logging — NOTE(review): the /step handler does not read this field yet
class NegotiateRequest(BaseModel):
    """JSON request body for ``POST /negotiate``."""

    target: str # "blocker" or "traffic" — not validated here; passed as-is to env.negotiate()
    message: str # free-text message forwarded to env.negotiate()
# ── Endpoints ─────────────────────────────────────────────────────────
@app.get("/health")
def health():
return {"status": "ok", "version": "0.2.1"}
@app.get("/reset")
def reset():
global obs
obs, info = env.reset()
return {
"observation": obs.tolist(),
"render": env.render(),
"lidar": env.lidar_scan(),
"info": info,
}
@app.post("/step")
def step(req: StepRequest):
global obs
if req.action not in [0, 1, 2, 3]:
raise HTTPException(400, "action must be 0-3")
obs, reward, done, truncated, info = env.step(req.action)
return {
"observation": obs.tolist(),
"reward": round(reward, 4),
"done": done,
"truncated": truncated,
"info": info,
"render": env.render(),
"lidar": env.lidar_scan(),
"collision": env.predict_collision(),
"negotiation_log": env.negotiation_log[-5:],
}
@app.get("/lidar")
def lidar():
return {
**env.lidar_scan(),
**env.predict_collision(),
}
@app.post("/negotiate")
def negotiate(req: NegotiateRequest):
response = env.negotiate(req.target, req.message)
return {
"response": response,
"negotiation_log": env.negotiation_log,
"blocker_yielding": env._blocker_yielding,
}
@app.get("/state")
def state():
return {
"ego": env.ego.tolist(),
"blocker": env.blocker.tolist(),
"traffic": env.traffic.tolist(),
"step": env.step_count,
"render": env.render(),
"memory": env.memory[-5:],
"neg_log": env.negotiation_log,
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)