File size: 4,713 Bytes
c34e7cc
 
 
 
 
 
 
facabc7
 
 
 
d1600e6
facabc7
 
1f7679c
d1600e6
 
 
c34e7cc
900f69f
 
 
facabc7
 
 
f1982b6
900f69f
 
 
 
f1982b6
900f69f
c34e7cc
900f69f
c34e7cc
 
 
 
 
f1982b6
 
 
 
 
 
 
 
facabc7
f1982b6
 
71c1208
 
 
 
f1982b6
71c1208
 
f1982b6
71c1208
 
facabc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1982b6
 
 
facabc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1982b6
c34e7cc
 
 
f1982b6
c34e7cc
 
f1982b6
c34e7cc
facabc7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
FastAPI application for the Distributed Infrastructure Environment.

Usage:
    uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
"""

import asyncio
import json
import logging
import os

from fastapi import WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from openenv.core.env_server.http_server import create_app

from server.environment import DistributedInfraEnvironment
from server.models import InfraAction, InfraObservation

# --- Singleton factory pattern ---
# 1. Create the environment instance once, at import time; env_factory()
#    below hands this same object back on every call.
_global_env = DistributedInfraEnvironment()
# Separate environment instance that the /ws/simulation stream steps, so the
# visual feed does not touch _global_env.
_viz_env = DistributedInfraEnvironment()
# Guards every read/step/reset of _viz_env done by the websocket handler.
_viz_lock = asyncio.Lock()


# 2. Create a "factory function" that returns our active instance
def env_factory():
    """Return the shared module-level environment; no new instance is created."""
    return _global_env


# 3. Pass the callable factory function to OpenEnv, together with the action
#    and observation model classes. The returned `app` is the object that the
#    middleware and route decorators below attach to.
app = create_app(
    env_factory,
    InfraAction,
    InfraObservation,
    env_name="distributed_infra_env",
)

# --- CORS for Next.js frontend ---
# NOTE(review): the "*" entry makes the two explicit localhost origins
# redundant, and together with allow_credentials=True it presumably lets any
# origin make credentialed requests — confirm this is intended outside local
# development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000", "*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Serve the formatted landing page (home.html) at the root URL
@app.get("/")
def home():
    """Return the contents of ``home.html`` as an HTML response.

    The file is looked up in the directory that contains this module, so the
    lookup does not depend on the process's current working directory.
    """
    module_dir = os.path.dirname(__file__)
    page_path = os.path.join(module_dir, "home.html")

    with open(page_path, "r", encoding="utf-8") as page:
        return HTMLResponse(content=page.read())


def _parse_action_payload(payload: dict) -> InfraAction:
    """
    Build an InfraAction from an intervention payload sent by a client.

    A truthy ``"command"`` entry takes precedence: it is stringified and
    passed as ``raw_command`` on a ``no_op`` action. Otherwise the structured
    fields (``action_type``, ``target``, ``from_node``, ``to_node``, ``rate``)
    are read from the payload; ``action_type`` falls back to ``"no_op"`` and
    the remaining fields to ``None`` when absent.
    """
    raw_command = payload.get("command")
    if raw_command:
        return InfraAction(action_type="no_op", raw_command=str(raw_command))

    return InfraAction(
        action_type=str(payload.get("action_type", "no_op")),
        target=payload.get("target"),
        from_node=payload.get("from_node"),
        to_node=payload.get("to_node"),
        rate=payload.get("rate"),
    )


@app.websocket("/ws/simulation")
async def simulation_socket(websocket: WebSocket):
    """
    Stream live DIME observations for visual frontends.

    Protocol:
    - Server emits a JSON payload roughly every 200ms.
    - Client may send optional intervention JSON:
      {"command":"kubectl throttle ingress --rate=0.3"}
      or structured action fields compatible with InfraAction.
    - A message that is not valid JSON, or that cannot be turned into an
      InfraAction, is logged and ignored; the stream keeps running.

    The handler returns quietly when the client disconnects
    (WebSocketDisconnect) or when a RuntimeError is raised while
    receiving/sending.
    """
    await websocket.accept()
    loop = asyncio.get_running_loop()
    log = logging.getLogger(__name__)

    # Ensure the visualization environment has an initialized episode.
    async with _viz_lock:
        if _viz_env.sim.step_count == 0 and not _viz_env.sim.nodes:
            _viz_env.reset(task="cascading_failure")

    pending_action: InfraAction | None = None
    pending_command: str | None = None

    try:
        while True:
            # Near-non-blocking receive so we can preserve a fixed tick-rate
            # stream even when the client sends nothing.
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=0.001)
            except asyncio.TimeoutError:
                pass
            else:
                try:
                    incoming = json.loads(raw) if raw else {}
                    if isinstance(incoming, dict):
                        pending_action = _parse_action_payload(incoming)
                        pending_command = (
                            str(incoming.get("command"))
                            if incoming.get("command")
                            else incoming.get("action_type")
                        )
                except ValueError as exc:
                    # json.JSONDecodeError is a ValueError; model validation
                    # errors from InfraAction are presumably ValueError
                    # subclasses too (pydantic) — TODO confirm. A bad client
                    # message must not tear down the whole stream.
                    log.warning("Ignoring malformed intervention payload: %s", exc)

            async with _viz_lock:
                action = pending_action or InfraAction(action_type="no_op")
                pending_action = None

                obs = _viz_env.step(action=action)
                if obs.done:
                    obs = _viz_env.reset(
                        task=_viz_env.sim.task_id or "cascading_failure"
                    )

                # Snapshot everything that reads shared state while the lock
                # is held; the network send happens after it is released.
                frame = {
                    "observation": obs.model_dump(),
                    "intervention": pending_command,
                    "last_action_type": _viz_env.sim.last_action_type,
                    "timestamp_ms": int(loop.time() * 1000),
                }
                pending_command = None

            # Sending outside the lock keeps one slow client from stalling
            # the other connections that share _viz_env.
            await websocket.send_json(frame)

            await asyncio.sleep(0.2)
    except (WebSocketDisconnect, RuntimeError):
        return


def main():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000 when run as a script."""
    # Imported lazily so that importing this module does not import uvicorn.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)


# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()