File size: 3,059 Bytes
3d58f38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
OpenEnv-compatible FastAPI Server
Endpoints: GET /health  GET /reset  POST /step  GET /lidar  POST /negotiate  GET /state

Deploy: uvicorn server.server:app --host 0.0.0.0 --port 7860
"""

import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from env.negotiation_env import NegotiationDrivingEnv

app = FastAPI(version="0.2.1", title="Autonomous Driving OpenEnv")

# Fully permissive CORS: any origin may call any method with any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# One environment instance for the whole process; every endpoint below reads
# and mutates this same object.
# NOTE(review): FastAPI runs plain `def` endpoints in a threadpool, so
# concurrent requests can interleave on this shared env — consider a lock if
# more than one client is expected.
env = NegotiationDrivingEnv()

# Start an episode at import time. `obs` holds the latest observation and is
# rebound by the /reset and /step handlers (via `global obs`).
# Presumably this reset is needed so /step can be called before any /reset —
# confirm against NegotiationDrivingEnv.
obs, _ = env.reset()


# ── Request / Response models ─────────────────────────────────────────

class StepRequest(BaseModel):
    """Request body for ``POST /step``."""

    action: int          # 0=accelerate 1=brake 2=lane_left 3=lane_right; step() rejects anything outside 0-3 with HTTP 400
    reasoning: str = ""  # optional LLM chain-of-thought; NOTE(review): the visible /step handler never reads it — confirm it is logged elsewhere

class NegotiateRequest(BaseModel):
    """Request body for ``POST /negotiate``."""

    target: str          # "blocker" or "traffic"; not validated here, passed straight to env.negotiate()
    message: str         # free-text message forwarded to env.negotiate()


# ── Endpoints ─────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Liveness check: report that the server is up and which version it runs."""
    payload = dict(status="ok", version="0.2.1")
    return payload


@app.get("/reset")
def reset():
    """Start a new episode and return its initial state.

    Rebinds the module-level ``obs`` to the fresh observation.

    Returns:
        dict with ``observation`` (``obs.tolist()``), ``render``
        (``env.render()``), ``lidar`` (``env.lidar_scan()``) and ``info``
        (the second value returned by ``env.reset()``).
    """
    global obs
    obs, reset_info = env.reset()
    # Keyword arguments are evaluated left to right, so render() and
    # lidar_scan() are called in the same order as before.
    payload = dict(
        observation=obs.tolist(),
        render=env.render(),
        lidar=env.lidar_scan(),
        info=reset_info,
    )
    return payload


@app.post("/step")
def step(req: StepRequest):
    """Apply one action to the environment and return the resulting state.

    Rebinds the module-level ``obs`` to the new observation.

    Args:
        req: The action to apply (0=accelerate, 1=brake, 2=lane_left,
            3=lane_right) plus optional reasoning text. ``req.reasoning`` is
            not used by this handler.

    Returns:
        A JSON-ready dict with the new observation, the reward rounded to 4
        decimals, the ``done``/``truncated`` flags, ``info``, the render and
        lidar output, the collision prediction, and the last 5
        negotiation-log entries.

    Raises:
        HTTPException: 400 when ``req.action`` is outside 0-3.
    """
    global obs
    # range(4) is an O(1) integer membership test covering exactly 0-3.
    if req.action not in range(4):
        raise HTTPException(400, "action must be 0-3")

    obs, reward, done, truncated, info = env.step(req.action)

    return {
        "observation":     obs.tolist(),
        # Coerce to built-in types: env.step() may return numpy scalars
        # (e.g. np.float32 reward, np.bool_ flags), which FastAPI's JSON
        # encoder cannot serialize and which would turn this into a 500.
        "reward":          round(float(reward), 4),
        "done":            bool(done),
        "truncated":       bool(truncated),
        "info":            info,
        "render":          env.render(),
        "lidar":           env.lidar_scan(),
        "collision":       env.predict_collision(),
        "negotiation_log": env.negotiation_log[-5:],
    }


@app.get("/lidar")
def lidar():
    """Return the current lidar scan merged with the collision prediction.

    Keys produced by ``env.predict_collision()`` override any same-named keys
    from ``env.lidar_scan()``.
    """
    combined = dict(env.lidar_scan())
    combined.update(env.predict_collision())
    return combined


@app.post("/negotiate")
def negotiate(req: NegotiateRequest):
    """Send a negotiation message to ``req.target`` and return the reply.

    Returns:
        dict with the reply from ``env.negotiate``, the full negotiation log
        (not truncated, unlike /step), and the blocker's yielding flag.
    """
    reply = env.negotiate(req.target, req.message)
    full_log = env.negotiation_log
    # NOTE(review): reads a private attribute of the env; a public accessor
    # would make this contract explicit.
    yielding = env._blocker_yielding
    return {
        "response": reply,
        "negotiation_log": full_log,
        "blocker_yielding": yielding,
    }


@app.get("/state")
def state():
    """Return a snapshot of the environment's internal state.

    Includes the ego/blocker/traffic arrays (as lists), the step counter, the
    render output, the last 5 memory entries, and the full negotiation log.
    """
    # Keyword arguments are evaluated left to right, preserving the original
    # access order (render() included).
    snapshot = dict(
        ego=env.ego.tolist(),
        blocker=env.blocker.tolist(),
        traffic=env.traffic.tolist(),
        step=env.step_count,
        render=env.render(),
        memory=env.memory[-5:],
        neg_log=env.negotiation_log,
    )
    return snapshot


if __name__ == "__main__":
    # Direct execution: serve the app in-process instead of via the uvicorn CLI.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")