Spaces:
Configuration error
Configuration error
File size: 4,713 Bytes
c34e7cc facabc7 d1600e6 facabc7 1f7679c d1600e6 c34e7cc 900f69f facabc7 f1982b6 900f69f f1982b6 900f69f c34e7cc 900f69f c34e7cc f1982b6 facabc7 f1982b6 71c1208 f1982b6 71c1208 f1982b6 71c1208 facabc7 f1982b6 facabc7 f1982b6 c34e7cc f1982b6 c34e7cc f1982b6 c34e7cc facabc7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | """
FastAPI application for the Distributed Infrastructure Environment.
Usage:
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
"""
import asyncio
import json
import os
from openenv.core.env_server.http_server import create_app
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from server.environment import DistributedInfraEnvironment
from server.models import InfraAction, InfraObservation
# --- Shared environment instances ---
# create_app() below is handed a factory callable, not an instance. The factory
# always returns the one module-level environment, so every call to it yields
# the same DistributedInfraEnvironment object.
# NOTE(review): whether create_app() invokes the factory once or per session is
# not visible from this file -- confirm against openenv.
_global_env: DistributedInfraEnvironment = DistributedInfraEnvironment()

# A second, independent instance backs the /ws/simulation visual stream;
# _viz_lock serializes access to it across websocket handlers.
_viz_env: DistributedInfraEnvironment = DistributedInfraEnvironment()
_viz_lock: asyncio.Lock = asyncio.Lock()


def env_factory() -> DistributedInfraEnvironment:
    """Return the shared HTTP-API environment (the same object on every call)."""
    return _global_env


# Build the OpenEnv HTTP app around the shared-environment factory.
app = create_app(env_factory, InfraAction, InfraObservation, env_name="distributed_infra_env")
# --- CORS for the Next.js frontend ---
# The default list is kept for backward compatibility. Because it contains "*",
# the two localhost entries are redundant, and pairing a wildcard origin with
# allow_credentials=True is the combination the Starlette CORS docs warn
# against. NOTE(review): depending on the Starlette version, the request Origin
# may simply be echoed back, i.e. credentialed access from any site.
# Deployments can restrict this by setting CORS_ALLOW_ORIGINS to a
# comma-separated list of origins.
_DEFAULT_CORS_ORIGINS = ["http://localhost:3000", "http://127.0.0.1:3000", "*"]
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or _DEFAULT_CORS_ORIGINS
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Root URL: serve the human-readable landing page.
@app.get("/")
def home():
    """Return ``home.html`` (located next to this module) as an HTML response."""
    # Resolve relative to this file, not the CWD, so launch location is irrelevant.
    here = os.path.dirname(__file__)
    with open(os.path.join(here, "home.html"), "r", encoding="utf-8") as page:
        return HTMLResponse(content=page.read())
def _parse_action_payload(payload: dict) -> InfraAction:
    """
    Parse a client-sent intervention payload into an InfraAction.

    Two payload shapes are accepted:

    * ``{"command": "<raw command string>"}`` -- a non-empty ``command`` takes
      precedence and is forwarded (stringified) as ``raw_command`` on a
      ``no_op`` action.
    * Structured fields: ``action_type`` plus optional ``target``,
      ``from_node``, ``to_node`` and ``rate`` (missing ones are passed as None).

    A missing, ``null`` or empty ``action_type`` falls back to ``"no_op"``.

    Invalid field values raise whatever ``InfraAction`` raises on construction
    (presumably a pydantic ``ValidationError`` -- confirm in server.models).
    """
    command = payload.get("command")
    if command:
        return InfraAction(action_type="no_op", raw_command=str(command))
    # Use `or` rather than a .get() default so an explicit JSON null does not
    # become the literal string "None" (and "" does not slip through).
    action_type = str(payload.get("action_type") or "no_op")
    return InfraAction(
        action_type=action_type,
        target=payload.get("target"),
        from_node=payload.get("from_node"),
        to_node=payload.get("to_node"),
        rate=payload.get("rate"),
    )
async def _poll_intervention(websocket: WebSocket) -> "dict | None":
    """
    Return the next client message as a dict, or None if nothing usable arrived.

    Waits at most 1 ms so the caller keeps its fixed tick rate. Returns None
    when no message arrived in time, when the frame is not valid JSON, or when
    the decoded JSON is not an object. An empty text frame yields ``{}``.
    ``WebSocketDisconnect`` from ``receive_text()`` propagates to the caller.
    """
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=0.001)
    except asyncio.TimeoutError:
        return None
    except KeyError:
        # NOTE(review): Starlette's receive_text() reads message["text"], which a
        # binary frame lacks; ignore such frames instead of killing the stream.
        return None
    try:
        incoming = json.loads(raw) if raw else {}
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return None
    return incoming if isinstance(incoming, dict) else None


@app.websocket("/ws/simulation")
async def simulation_socket(websocket: WebSocket):
    """
    Stream live DIME observations for visual frontends.

    Protocol:
    - Server emits a JSON payload roughly every 200ms with keys
      ``observation``, ``intervention``, ``last_action_type`` and
      ``timestamp_ms``.
    - Client may send optional intervention JSON:
      {"command":"kubectl throttle ingress --rate=0.3"}
      or structured action fields compatible with InfraAction.
      Malformed messages (invalid JSON, non-object JSON, fields InfraAction
      rejects, binary frames) are ignored instead of closing the stream.

    ``timestamp_ms`` is taken from the event loop's monotonic clock, so it is
    meaningful for measuring intervals, not as a wall-clock time.
    """
    await websocket.accept()
    # Start an episode on the shared visualization env if none has run yet.
    async with _viz_lock:
        if _viz_env.sim.step_count == 0 and not _viz_env.sim.nodes:
            _viz_env.reset(task="cascading_failure")

    # NOTE(review): _viz_env is shared by every connection and each connection
    # steps it once per tick, so N open sockets advance the simulation ~N times
    # per 200ms and one client's interventions affect all. Confirm intended.
    loop = asyncio.get_running_loop()
    try:
        while True:
            action: InfraAction | None = None
            intervention: str | None = None

            incoming = await _poll_intervention(websocket)
            if incoming is not None:
                try:
                    action = _parse_action_payload(incoming)
                except ValueError:
                    # Rejected action fields: drop this message, keep streaming.
                    # (Assumes InfraAction is a pydantic model, whose
                    # ValidationError subclasses ValueError.)
                    action = None
                else:
                    command = incoming.get("command")
                    intervention = (
                        str(command) if command else incoming.get("action_type")
                    )

            async with _viz_lock:
                obs = _viz_env.step(action=action or InfraAction(action_type="no_op"))
                if obs.done:
                    # Loop forever: restart the same task when an episode ends.
                    obs = _viz_env.reset(
                        task=_viz_env.sim.task_id or "cascading_failure"
                    )
                last_action_type = _viz_env.sim.last_action_type

            await websocket.send_json(
                {
                    "observation": obs.model_dump(),
                    "intervention": intervention,
                    "last_action_type": last_action_type,
                    "timestamp_ms": int(loop.time() * 1000),
                }
            )
            await asyncio.sleep(0.2)
    except (WebSocketDisconnect, RuntimeError):
        # WebSocketDisconnect: the client went away. RuntimeError: presumably
        # Starlette refusing send/receive on a socket that is no longer open.
        return
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Run the app under uvicorn (entry point for direct execution).

    Args:
        host: Interface to bind; the default binds all interfaces.
        port: TCP port to listen on.
    """
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|