Spaces:
Configuration error
Configuration error
| """ | |
| FastAPI application for the Distributed Infrastructure Environment. | |
| Usage: | |
| uvicorn server.app:app --reload --host 0.0.0.0 --port 8000 | |
| """ | |
| import asyncio | |
| import json | |
| import os | |
| from openenv.core.env_server.http_server import create_app | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from server.environment import DistributedInfraEnvironment | |
| from server.models import InfraAction, InfraObservation | |
# --- Singleton factory ------------------------------------------------------
# Both environments are constructed exactly once, when this module is imported.
# _global_env is the instance handed to OpenEnv through env_factory below;
# _viz_env is a separate instance that, in this file, is only driven by the
# visualization websocket, and _viz_lock guards the step/reset calls on it.
# NOTE(review): asyncio.Lock() is created outside a running event loop. That is
# fine on Python 3.10+, but older versions bind the lock to the loop that is
# current at import time — confirm the target runtime.
_global_env = DistributedInfraEnvironment()
_viz_env = DistributedInfraEnvironment()
_viz_lock = asyncio.Lock()


def env_factory():
    """Return the shared ``_global_env`` instance; never builds a new one.

    OpenEnv is given this callable instead of an environment class, so every
    invocation yields the very same environment object.
    """
    shared_env = _global_env
    return shared_env
# --- OpenEnv application ----------------------------------------------------
# create_app() receives the factory callable (not an environment instance)
# together with the action and observation model classes.
app = create_app(
    env_factory,
    InfraAction,
    InfraObservation,
    env_name="distributed_infra_env",
)

# --- CORS for the Next.js frontend ------------------------------------------
# NOTE(review): "*" already admits every origin, which makes the two explicit
# localhost:3000 entries redundant. Combined with allow_credentials=True this
# is probably broader than intended for a deployed Space — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000", "*"],
)
# Handler for the formatted documentation page intended for the root URL.
def home():
    """Return ``home.html`` from this module's directory as an HTMLResponse.

    The file is re-read on every call, so edits to it are picked up without a
    restart.

    NOTE(review): no route decorator is attached to this function and none is
    visible in this file. Given the comment above, it was presumably meant to
    be registered with ``@app.get("/")`` — confirm, otherwise it is unreachable.
    """
    page_path = os.path.join(os.path.dirname(__file__), "home.html")
    with open(page_path, encoding="utf-8") as page:
        return HTMLResponse(content=page.read())
def _parse_action_payload(payload: dict) -> InfraAction:
    """
    Parse a client-sent intervention payload into an InfraAction.

    Two payload shapes are supported:

    - Raw command: if ``payload["command"]`` is truthy, it is stringified and
      passed as ``raw_command`` with ``action_type="no_op"``. All other keys
      are ignored.
    - Structured fields: ``action_type`` (defaults to ``"no_op"`` when the key
      is missing *or* null) plus the optional ``target``, ``from_node``,
      ``to_node`` and ``rate`` keys are forwarded to InfraAction. Missing
      optional keys are forwarded as ``None``.

    Args:
        payload: Decoded JSON object received from the client.

    Returns:
        The InfraAction built from the payload.

    Raises:
        Whatever InfraAction raises for field values it rejects (presumably a
        pydantic ValidationError — confirm against server.models).
    """
    command = payload.get("command")
    if command:
        return InfraAction(action_type="no_op", raw_command=str(command))

    # An explicit JSON null must fall back to the default rather than being
    # stringified into the literal action type "None".
    action_type = payload.get("action_type")
    if action_type is None:
        action_type = "no_op"

    return InfraAction(
        action_type=str(action_type),
        target=payload.get("target"),
        from_node=payload.get("from_node"),
        to_node=payload.get("to_node"),
        rate=payload.get("rate"),
    )
async def simulation_socket(websocket: WebSocket):
    """
    Stream live DIME observations for visual frontends.

    Protocol:
    - Server emits a JSON payload roughly every 200ms with the keys
      ``observation`` (``obs.model_dump()``), ``intervention`` (the command or
      action_type received since the previous tick, else ``None``),
      ``last_action_type`` and ``timestamp_ms``.
    - Client may send optional intervention JSON:
        {"command":"kubectl throttle ingress --rate=0.3"}
      or structured action fields compatible with InfraAction.

    ``timestamp_ms`` is taken from the event loop's clock (``loop.time()``),
    which is monotonic. It is meaningful for deltas between frames, not as a
    wall-clock epoch value.

    A malformed client message (invalid JSON, or fields that InfraAction
    rejects) is dropped instead of ending the stream. The handler returns
    quietly on WebSocketDisconnect or RuntimeError.

    NOTE(review): every connected client steps the shared ``_viz_env`` on its
    own ~200ms cadence, so N clients advance the simulation about N times
    faster.
    NOTE(review): no ``@app.websocket(...)`` registration for this handler is
    visible in this file — confirm it is mounted somewhere.
    """
    await websocket.accept()
    loop = asyncio.get_running_loop()

    # Ensure the visualization environment has an initialized episode.
    async with _viz_lock:
        if _viz_env.sim.step_count == 0 and not _viz_env.sim.nodes:
            _viz_env.reset(task="cascading_failure")

    pending_action: InfraAction | None = None
    pending_command: str | None = None
    try:
        while True:
            # Near-non-blocking receive (1ms) so the stream keeps its tick rate.
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=0.001)
            except asyncio.TimeoutError:
                raw = None

            if raw is not None:
                try:
                    incoming = json.loads(raw) if raw else {}
                    if isinstance(incoming, dict):
                        pending_action = _parse_action_payload(incoming)
                        pending_command = (
                            str(incoming.get("command"))
                            if incoming.get("command")
                            else incoming.get("action_type")
                        )
                except ValueError:
                    # json.JSONDecodeError and pydantic's ValidationError both
                    # subclass ValueError. Drop the bad message and keep the
                    # stream alive instead of tearing down the connection.
                    pass

            async with _viz_lock:
                action = pending_action or InfraAction(action_type="no_op")
                pending_action = None
                obs = _viz_env.step(action=action)
                if obs.done:
                    obs = _viz_env.reset(
                        task=_viz_env.sim.task_id or "cascading_failure"
                    )
                # Snapshot the frame while holding the lock; the network send
                # below happens outside it so a slow client cannot stall the
                # shared environment for everyone else.
                frame = {
                    "observation": obs.model_dump(),
                    "intervention": pending_command,
                    "last_action_type": _viz_env.sim.last_action_type,
                    "timestamp_ms": int(loop.time() * 1000),
                }
            await websocket.send_json(frame)
            pending_command = None
            await asyncio.sleep(0.2)
    except (WebSocketDisconnect, RuntimeError):
        return
def main():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000 when run as a script."""
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")


if __name__ == "__main__":
    main()