File size: 6,833 Bytes
4904e85
 
 
 
f4ed234
4904e85
1df309d
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f212fd
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7423aa7
1df309d
 
 
910875a
 
1df309d
910875a
 
 
 
 
1df309d
 
910875a
 
 
 
 
 
7423aa7
 
4904e85
 
8c359c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d762f3
 
f4ed234
 
 
 
 
 
 
 
 
1d762f3
 
f4ed234
 
1d762f3
f4ed234
1d762f3
f4ed234
1d762f3
 
f4ed234
 
 
 
1d762f3
4904e85
 
e259b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4904e85
2f212fd
 
 
 
 
 
 
 
 
 
 
4904e85
2f212fd
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14170d7
 
 
 
 
 
 
 
 
 
 
 
 
4904e85
 
 
 
 
6172160
4904e85
 
 
 
 
6172160
4904e85
 
 
 
 
43f2683
4904e85
f4ed234
43f2683
14170d7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""OpenEnv server implementing reset/step/state endpoints."""

from typing import Any

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel

from src.models import Action, Observation, State
from src.openenv_environment import OpenEnvEnvironment

app = FastAPI(title="911 — Dispatch API")

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site send credentialed requests; confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Single process-wide environment; (re)created by POST /reset and read by the
# other endpoints. None until the first reset.
_env: OpenEnvEnvironment | None = None


# There is intentionally no ResetRequest model: /reset parses the raw Request itself so that a missing or null JSON body falls back to the defaults.

class StepRequest(BaseModel):
    # Body of POST /step. ``action`` is the raw action mapping; the endpoint
    # validates it into an ``Action`` model before stepping the environment.
    # (No docstring on purpose: pydantic would copy it into the JSON schema.)
    action: dict[str, Any]


class StepResponse(BaseModel):
    # Response body of POST /step; field order fixes the JSON key order.
    observation: dict[str, Any]  # Observation.model_dump() after applying the action
    reward: float  # reward value returned by the environment's step()
    done: bool  # done flag returned by the environment's step()


@app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, exc: RuntimeError) -> JSONResponse:
    """Convert an uncaught ``RuntimeError`` into a JSON HTTP 500 response.

    Endpoints in this module raise ``RuntimeError`` when the environment has
    not been initialised and when an action fails validation. The exception
    message is returned as ``{"detail": <message>}``, which is the same body
    shape FastAPI uses for ``HTTPException``.

    Args:
        request: The request being handled (unused).
        exc: The raised error; its string form becomes ``detail``.
    """
    # JSONResponse is already imported at module level; no local re-import needed.
    return JSONResponse(status_code=500, content={"detail": str(exc)})


@app.get("/", include_in_schema=False)
async def root():
    """Serve ``live_dashboard.html`` on the root route (HF Spaces landing page).

    The file is looked for first in the directory two levels above the one
    containing this module (presumably the repository root), then in the
    current working directory. If neither copy exists, a JSON payload is
    returned that reports the miss along with the last path tried.
    """
    import os

    filename = "live_dashboard.html"

    # Three dirname() hops: file -> its package dir -> parent -> grandparent.
    project_root = os.path.abspath(__file__)
    for _ in range(3):
        project_root = os.path.dirname(project_root)

    candidate = os.path.join(project_root, filename)
    if not os.path.exists(candidate):
        # Fallback: look next to wherever the process was started.
        candidate = os.path.join(os.getcwd(), filename)

    if not os.path.exists(candidate):
        # NOTE(review): "debug" exposes a server filesystem path to clients.
        return JSONResponse({
            "status": "healthy",
            "error": "dashboard not found",
            "debug": candidate,
        })
    return FileResponse(candidate)


@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe.

    The OpenEnv runtime validator expects ``status`` to be ``"healthy"``.
    """
    payload = {"status": "healthy"}
    return payload


@app.get("/metadata")
async def metadata() -> dict[str, Any]:
    """Describe this environment (name, description, version, mode) for OpenEnv validators."""
    description = (
        "City-wide 911 emergency dispatch supervisor RL environment. "
        "An LLM agent learns to manage simultaneous incidents by dispatching "
        "police, fire, and EMS units across a city grid under realistic constraints."
    )
    return dict(
        name="citywide-dispatch-supervisor",
        description=description,
        version="0.1.0",
        mode="simulation",
    )


@app.get("/schema")
async def schema() -> dict[str, Any]:
    """Expose the pydantic JSON schemas of the Action, Observation and State models."""
    models = (
        ("action", Action),
        ("observation", Observation),
        ("state", State),
    )
    return {key: model.model_json_schema() for key, model in models}


def _jsonrpc_error(req_id: Any, code: int, message: str, status_code: int) -> JSONResponse:
    """Build a JSON-RPC 2.0 error response with the given code, message and HTTP status."""
    return JSONResponse(
        {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}},
        status_code=status_code,
    )


@app.post("/mcp")
async def mcp_endpoint(request: Request):
    """MCP JSON-RPC passthrough for OpenEnv runtime compatibility.

    Dispatches a JSON-RPC 2.0 request ``{"method", "params", "id"}`` to the
    shared environment. Supported methods:

    * ``reset``: reset the environment. If no environment exists yet, one is
      created from the optional ``params.task_id`` (default
      ``"single_incident"``) and ``params.seed``; otherwise ``params`` is ignored.
    * ``step``: validate ``params.action`` as an ``Action`` and advance one step.
    * ``state``: return the current ``State``.
    * ``legal_actions``: return the currently legal actions.

    Error responses (a JSON-RPC ``error`` object unless noted):

    * unparsable body -> HTTP 400 ``{"error": "invalid JSON"}``
    * body is not a JSON object -> HTTP 400, code -32600
    * unknown method -> HTTP 404, code -32601
    * environment not initialised (step/state/legal_actions) -> HTTP 409, code -32000
    * non-object ``params`` or invalid action for ``step`` -> HTTP 400, code -32602
    """
    global _env

    try:
        body = await request.json()
    except Exception:
        return JSONResponse({"error": "invalid JSON"}, status_code=400)

    # Valid JSON that is not an object (list, string, number, null) cannot be
    # a JSON-RPC request; reject it rather than crash on ``body.get``.
    if not isinstance(body, dict):
        return _jsonrpc_error(None, -32600, "Invalid Request: body must be a JSON object", 400)

    method = body.get("method", "")
    req_id = body.get("id", 1)
    params = body.get("params")
    if params is None:
        params = {}

    if method == "reset":
        if _env is None:
            # First use: build the environment with the same default task as POST /reset.
            options = params if isinstance(params, dict) else {}
            task_id = options.get("task_id")
            _env = OpenEnvEnvironment(
                task_id="single_incident" if task_id is None else task_id,
                seed=options.get("seed"),
            )
        result = await _env.reset()
        return {"jsonrpc": "2.0", "id": req_id, "result": result.model_dump()}

    if method not in ("step", "state", "legal_actions"):
        return _jsonrpc_error(req_id, -32601, f"Method not found: {method}", 404)

    # Every remaining method needs an environment; fail cleanly instead of
    # raising AttributeError on None.
    if _env is None:
        return _jsonrpc_error(
            req_id, -32000, "Environment not initialized. Call reset first.", 409
        )

    if method == "step":
        if not isinstance(params, dict):
            return _jsonrpc_error(req_id, -32602, "Invalid params: 'params' must be an object", 400)
        action_data = params.get("action")
        try:
            action = Action.model_validate({} if action_data is None else action_data)
        except Exception as exc:
            # Bad client input is reported as JSON-RPC "invalid params", not a server crash.
            return _jsonrpc_error(req_id, -32602, f"Invalid action: {exc}", 400)
        obs, reward, done = await _env.step(action)
        return {
            "jsonrpc": "2.0",
            "id": req_id,
            "result": {"observation": obs.model_dump(), "reward": reward, "done": done},
        }

    if method == "state":
        result = _env.state()
        return {"jsonrpc": "2.0", "id": req_id, "result": result.model_dump()}

    # method == "legal_actions"
    actions = _env.legal_actions()
    return {"jsonrpc": "2.0", "id": req_id, "result": [a.model_dump() for a in actions]}


@app.get("/tasks")
async def list_tasks() -> list[dict[str, str]]:
    """Summarise every registered task as {task_id, name, description, difficulty}."""
    from src.tasks.registry import TaskRegistry

    fields = ("task_id", "name", "description", "difficulty")
    return [
        {field: getattr(task, field) for field in fields}
        for task in TaskRegistry.list_tasks()
    ]


@app.post("/reset")
async def reset(request: Request) -> dict[str, Any]:
    """Create a fresh environment and return its initial observation.

    The JSON body is optional. Recognised keys:

    * ``task_id``: task to load; ``"single_incident"`` when absent or null.
    * ``seed``: optional seed forwarded to ``OpenEnvEnvironment``.

    A missing, unparsable, null, or non-object body (e.g. ``[]`` or ``"x"``)
    is treated as ``{}``, so a bare ``POST /reset`` uses the defaults.

    Returns:
        The initial observation as a plain dict.
    """
    global _env

    try:
        body = await request.json()
    except Exception:
        # Empty or malformed body: deliberately fall back to defaults.
        body = None
    # Only a JSON object can carry options; anything else means "use defaults"
    # (previously a list/string body crashed on ``body.get``).
    if not isinstance(body, dict):
        body = {}

    task_id = body.get("task_id")
    if task_id is None:
        task_id = "single_incident"
    seed = body.get("seed")

    _env = OpenEnvEnvironment(task_id=task_id, seed=seed)
    obs = await _env.reset()
    return obs.model_dump()


@app.post("/step")
async def step(request: StepRequest) -> StepResponse:
    """Apply one action to the current environment.

    Args:
        request: Body whose ``action`` mapping is validated into an ``Action``.

    Returns:
        The resulting observation, reward and done flag.

    Raises:
        RuntimeError: if ``/reset`` has not been called yet, or if the action
            fails validation (the original error is chained as ``__cause__``).
    """
    if _env is None:
        raise RuntimeError("Environment not initialized. Call /reset first.")
    try:
        action = Action.model_validate(request.action)
    except Exception as e:
        # Chain the cause so the full validation traceback reaches the logs.
        raise RuntimeError(f"Invalid action: {e}") from e
    obs, reward, done = await _env.step(action)
    return StepResponse(
        observation=obs.model_dump(),
        reward=reward,
        done=done,
    )


@app.get("/state")
async def get_state() -> dict[str, Any]:
    """Return the current environment state as a plain dict.

    Raises:
        RuntimeError: if ``/reset`` has not been called yet.
    """
    if _env is not None:
        return _env.state().model_dump()
    raise RuntimeError("Environment not initialized. Call /reset first.")


@app.get("/dashboard/state")
async def get_dashboard_state() -> dict[str, Any]:
    """Return the state payload consumed by the HTML live dashboard.

    This is the ``/state`` dict extended with ``legal_actions``, ``issues`` and
    the last ``observation``, so the plain ``/state`` response stays stable
    for typed clients.
    """
    if _env is None:
        # Before the first /reset: an empty but well-formed payload.
        return {
            "units": {},
            "incidents": {},
            "episode_id": "not-initialized",
            "step_count": 0,
            "task_id": "none",
            "city_time": 0.0,
            "metadata": {},
            "legal_actions": [],
            "issues": [],
            "observation": None,
        }

    snapshot = _env.state().model_dump()
    legal = [action.model_dump() for action in _env.legal_actions()]
    observation = _env.last_observation()
    if observation is None:
        issues, observation_payload = [], None
    else:
        issues = list(observation.issues)
        observation_payload = observation.model_dump()

    # model_dump() returns a fresh dict, so it is safe to extend in place.
    snapshot.update(
        legal_actions=legal,
        issues=issues,
        observation=observation_payload,
    )
    return snapshot


def main():
    """Serve the app with uvicorn on 0.0.0.0; the port comes from $PORT (default 7860)."""
    import uvicorn
    import os

    uvicorn.run(
        "src.server.app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        reload=False,
    )


if __name__ == "__main__":
    main()