Lars Talian commited on
Commit
86c1c19
·
1 Parent(s): c51ff20

Implement issues #72 #74 #75 env grounding and console scope

Browse files
src/open_range/server/app.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  import logging
6
  import os
7
 
@@ -10,6 +11,24 @@ from fastapi import FastAPI
10
  logger = logging.getLogger(__name__)
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def create_app() -> FastAPI:
14
  """Create the OpenRange app through the canonical OpenEnv factory."""
15
  from openenv.core.env_server import create_app as create_openenv_app
@@ -37,6 +56,9 @@ def create_app() -> FastAPI:
37
  RangeObservation,
38
  env_name="open_range",
39
  )
 
 
 
40
 
41
  # Mount custom Gradio dashboard at /web if gradio is available
42
  try:
 
2
 
3
  from __future__ import annotations
4
 
5
+ import inspect
6
  import logging
7
  import os
8
 
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
+ def _extract_openenv_server(fastapp: FastAPI) -> object | None:
15
+ """Best-effort extraction of OpenEnv's HTTPEnvServer from route closure."""
16
+ for route in fastapp.router.routes:
17
+ if getattr(route, "path", None) != "/ws":
18
+ continue
19
+ endpoint = getattr(route, "endpoint", None)
20
+ if endpoint is None:
21
+ continue
22
+ try:
23
+ closure = inspect.getclosurevars(endpoint)
24
+ except Exception:
25
+ continue
26
+ server = closure.nonlocals.get("self")
27
+ if server is not None and hasattr(server, "active_sessions"):
28
+ return server
29
+ return None
30
+
31
+
32
  def create_app() -> FastAPI:
33
  """Create the OpenRange app through the canonical OpenEnv factory."""
34
  from openenv.core.env_server import create_app as create_openenv_app
 
56
  RangeObservation,
57
  env_name="open_range",
58
  )
59
+ openenv_server = _extract_openenv_server(fastapp)
60
+ if openenv_server is not None:
61
+ fastapp.state.openenv_server = openenv_server
62
 
63
  # Mount custom Gradio dashboard at /web if gradio is available
64
  try:
src/open_range/server/console.py CHANGED
@@ -47,10 +47,20 @@ def get_history(limit: int = 20) -> list[dict[str, Any]]:
47
  @console_router.get("/api/snapshot")
48
  async def api_snapshot(request: Request) -> JSONResponse:
49
  """Return current snapshot metadata (no truth graph or flags)."""
50
- env = _get_env(request)
 
51
  snapshot = env.snapshot
52
  if snapshot is None:
53
- return JSONResponse({"id": None, "tier": None, "hosts": [], "zones": {}, "vuln_count": 0})
 
 
 
 
 
 
 
 
 
54
 
55
  topo = snapshot.topology if isinstance(snapshot.topology, dict) else {}
56
  hosts = topo.get("hosts", [])
@@ -64,19 +74,26 @@ async def api_snapshot(request: Request) -> JSONResponse:
64
  "hosts": hosts,
65
  "zones": zones,
66
  "vuln_count": vuln_count,
 
 
 
67
  })
68
 
69
 
70
  @console_router.get("/api/episode")
71
  async def api_episode(request: Request) -> JSONResponse:
72
  """Return current episode state."""
73
- env = _get_env(request)
 
74
  state = env.state
75
  return JSONResponse({
76
  "step_count": state.step_count,
77
  "flags_found": len(state.flags_found),
78
  "mode": state.mode,
79
  "services_status": state.services_status,
 
 
 
80
  })
81
 
82
 
@@ -98,20 +115,70 @@ async def console_page() -> HTMLResponse:
98
  # ---------------------------------------------------------------------------
99
 
100
 
101
- def _get_env(request: Request) -> Any:
102
- """Retrieve the RangeEnvironment from the app's state.
103
 
104
- The app.py startup stores the environment instance as ``app.state.env``.
105
- If that attribute is missing we fall back to importing a fresh one.
 
 
106
  """
107
  app = request.app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  if hasattr(app.state, "env"):
109
- return app.state.env
110
- # Fallback: create an ephemeral environment (tests, etc.)
 
 
 
 
 
 
 
 
 
111
  from open_range.server.environment import RangeEnvironment
 
112
  if not hasattr(app.state, "_fallback_env"):
113
  app.state._fallback_env = RangeEnvironment(docker_available=False)
114
- return app.state._fallback_env
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
  # ---------------------------------------------------------------------------
 
47
  @console_router.get("/api/snapshot")
48
  async def api_snapshot(request: Request) -> JSONResponse:
49
  """Return current snapshot metadata (no truth graph or flags)."""
50
+ ctx = _get_env_context(request)
51
+ env = ctx["env"]
52
  snapshot = env.snapshot
53
  if snapshot is None:
54
+ return JSONResponse({
55
+ "id": None,
56
+ "tier": None,
57
+ "hosts": [],
58
+ "zones": {},
59
+ "vuln_count": 0,
60
+ "state_scope": ctx["state_scope"],
61
+ "session_id": ctx["session_id"],
62
+ "warning": ctx["warning"],
63
+ })
64
 
65
  topo = snapshot.topology if isinstance(snapshot.topology, dict) else {}
66
  hosts = topo.get("hosts", [])
 
74
  "hosts": hosts,
75
  "zones": zones,
76
  "vuln_count": vuln_count,
77
+ "state_scope": ctx["state_scope"],
78
+ "session_id": ctx["session_id"],
79
+ "warning": ctx["warning"],
80
  })
81
 
82
 
83
  @console_router.get("/api/episode")
84
  async def api_episode(request: Request) -> JSONResponse:
85
  """Return current episode state."""
86
+ ctx = _get_env_context(request)
87
+ env = ctx["env"]
88
  state = env.state
89
  return JSONResponse({
90
  "step_count": state.step_count,
91
  "flags_found": len(state.flags_found),
92
  "mode": state.mode,
93
  "services_status": state.services_status,
94
+ "state_scope": ctx["state_scope"],
95
+ "session_id": ctx["session_id"],
96
+ "warning": ctx["warning"],
97
  })
98
 
99
 
 
115
  # ---------------------------------------------------------------------------
116
 
117
 
118
+ def _get_env_context(request: Request) -> dict[str, Any]:
119
+ """Resolve the environment context used by the console endpoints.
120
 
121
+ Priority:
122
+ 1. Active OpenEnv WebSocket session environment (session-scoped truth)
123
+ 2. ``app.state.env`` fallback environment (global app scope)
124
+ 3. Lazily created fallback environment (tests/dev)
125
  """
126
  app = request.app
127
+
128
+ server = getattr(app.state, "openenv_server", None)
129
+ sessions = getattr(server, "_sessions", None)
130
+ if isinstance(sessions, dict) and sessions:
131
+ if len(sessions) == 1:
132
+ session_id, env = next(iter(sessions.items()))
133
+ return {
134
+ "env": env,
135
+ "state_scope": "websocket_session",
136
+ "session_id": session_id,
137
+ "warning": None,
138
+ }
139
+
140
+ session_info = getattr(server, "_session_info", {})
141
+ selected_id = max(
142
+ sessions.keys(),
143
+ key=lambda sid: float(getattr(session_info.get(sid), "last_activity_at", 0.0) or 0.0),
144
+ )
145
+ return {
146
+ "env": sessions[selected_id],
147
+ "state_scope": "websocket_session",
148
+ "session_id": selected_id,
149
+ "warning": (
150
+ f"{len(sessions)} active sessions detected; "
151
+ f"showing the most recently active session ({selected_id})."
152
+ ),
153
+ }
154
+
155
  if hasattr(app.state, "env"):
156
+ return {
157
+ "env": app.state.env,
158
+ "state_scope": "app_state_env",
159
+ "session_id": None,
160
+ "warning": (
161
+ "No active WebSocket session found; console is showing shared "
162
+ "app-state environment data."
163
+ ),
164
+ }
165
+
166
+ # Fallback: create an ephemeral environment (tests/dev)
167
  from open_range.server.environment import RangeEnvironment
168
+
169
  if not hasattr(app.state, "_fallback_env"):
170
  app.state._fallback_env = RangeEnvironment(docker_available=False)
171
+ return {
172
+ "env": app.state._fallback_env,
173
+ "state_scope": "fallback_env",
174
+ "session_id": None,
175
+ "warning": "Console is using a fallback environment (no server session available).",
176
+ }
177
+
178
+
179
+ def _get_env(request: Request) -> Any:
180
+ """Compatibility helper for callers that only need the env object."""
181
+ return _get_env_context(request)["env"]
182
 
183
 
184
  # ---------------------------------------------------------------------------
src/open_range/server/environment.py CHANGED
@@ -1236,6 +1236,62 @@ class RangeEnvironment(Environment[RangeAction, RangeObservation, RangeState]):
1236
  )
1237
  return self._container_name(name)
1238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1239
  # -----------------------------------------------------------------
1240
  # Core API
1241
  # -----------------------------------------------------------------
@@ -1313,6 +1369,9 @@ class RangeEnvironment(Environment[RangeAction, RangeObservation, RangeState]):
1313
  # Start NPC traffic for this episode
1314
  self._start_npcs(self._snapshot)
1315
 
 
 
 
1316
  # Build initial briefing
1317
  task = self._snapshot.task
1318
  if isinstance(task, dict):
@@ -1395,6 +1454,7 @@ class RangeEnvironment(Environment[RangeAction, RangeObservation, RangeState]):
1395
 
1396
  if cmd_name in meta_handlers:
1397
  obs = meta_handlers[cmd_name](action)
 
1398
  obs = self._apply_rewards(action, obs)
1399
  self._check_termination(obs)
1400
  self._report_if_done(obs)
@@ -1439,6 +1499,7 @@ class RangeEnvironment(Environment[RangeAction, RangeObservation, RangeState]):
1439
 
1440
  # Refresh NPC traffic log for reward computation
1441
  self._refresh_npc_traffic_log()
 
1442
 
1443
  # Build observation
1444
  obs = RangeObservation(
@@ -1574,8 +1635,8 @@ class RangeEnvironment(Environment[RangeAction, RangeObservation, RangeState]):
1574
 
1575
  In production (docker or subprocess mode with real infrastructure),
1576
  queries the SIEM container for actual log-based alerts. Falls back
1577
- to synthetic alerts derived from ALL Red actions when SIEM queries
1578
- return nothing or in unit-test mock mode.
1579
  """
1580
  # Try real SIEM query in non-mock modes
1581
  if self._docker_available is not False or self._execution_mode == "subprocess":
@@ -1583,7 +1644,23 @@ class RangeEnvironment(Environment[RangeAction, RangeObservation, RangeState]):
1583
  if siem_alerts:
1584
  return siem_alerts
1585
 
1586
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1587
 
1588
  # -----------------------------------------------------------------
1589
  # Introspection (for reward computation and debugging)
 
1236
  )
1237
  return self._container_name(name)
1238
 
1239
+ def _topology_host_names(self) -> list[str]:
1240
+ """Return deduplicated host names from the active snapshot topology."""
1241
+ if not self._snapshot or not isinstance(self._snapshot.topology, dict):
1242
+ return []
1243
+ hosts = self._snapshot.topology.get("hosts", [])
1244
+ names: list[str] = []
1245
+ for host in hosts:
1246
+ if isinstance(host, str):
1247
+ candidate = host
1248
+ elif isinstance(host, dict):
1249
+ candidate = host.get("name") or host.get("hostname") or ""
1250
+ else:
1251
+ candidate = ""
1252
+ name = str(candidate).strip()
1253
+ if name and name not in names:
1254
+ names.append(name)
1255
+ return names
1256
+
1257
+ def _refresh_services_status(self) -> None:
1258
+ """Refresh ``state.services_status`` from runtime/container health.
1259
+
1260
+ Availability reward should never rely on an empty status map after reset.
1261
+ When health cannot be verified, host status is marked ``"unknown"``.
1262
+ """
1263
+ host_names = self._topology_host_names()
1264
+ if not host_names:
1265
+ self._state.services_status = {}
1266
+ return
1267
+
1268
+ status_map = {host: "unknown" for host in host_names}
1269
+
1270
+ if self._execution_mode == "docker" and self._docker_available is not False:
1271
+ client = self._get_docker()
1272
+ if client is not None:
1273
+ for host in host_names:
1274
+ container_name = self._container_name(host)
1275
+ try:
1276
+ container = client.containers.get(container_name)
1277
+ status_map[host] = str(getattr(container, "status", "unknown") or "unknown")
1278
+ except Exception:
1279
+ status_map[host] = "down"
1280
+ self._state.services_status = status_map
1281
+ return
1282
+
1283
+ if self._execution_mode == "subprocess" and self._snapshot and self._snapshot.services:
1284
+ checks_by_host: dict[str, list[bool]] = {}
1285
+ for svc in self._snapshot.services:
1286
+ host = str(getattr(svc, "host", "") or "").strip()
1287
+ if not host:
1288
+ continue
1289
+ checks_by_host.setdefault(host, []).append(self._probe_readiness(svc.readiness))
1290
+ for host, checks in checks_by_host.items():
1291
+ status_map[host] = "healthy" if checks and all(checks) else "degraded"
1292
+
1293
+ self._state.services_status = status_map
1294
+
1295
  # -----------------------------------------------------------------
1296
  # Core API
1297
  # -----------------------------------------------------------------
 
1369
  # Start NPC traffic for this episode
1370
  self._start_npcs(self._snapshot)
1371
 
1372
+ # Prime service health map for availability reward grounding.
1373
+ self._refresh_services_status()
1374
+
1375
  # Build initial briefing
1376
  task = self._snapshot.task
1377
  if isinstance(task, dict):
 
1454
 
1455
  if cmd_name in meta_handlers:
1456
  obs = meta_handlers[cmd_name](action)
1457
+ self._refresh_services_status()
1458
  obs = self._apply_rewards(action, obs)
1459
  self._check_termination(obs)
1460
  self._report_if_done(obs)
 
1499
 
1500
  # Refresh NPC traffic log for reward computation
1501
  self._refresh_npc_traffic_log()
1502
+ self._refresh_services_status()
1503
 
1504
  # Build observation
1505
  obs = RangeObservation(
 
1635
 
1636
  In production (docker or subprocess mode with real infrastructure),
1637
  queries the SIEM container for actual log-based alerts. Falls back
1638
+ to synthetic alerts derived from Red action history when SIEM queries
1639
+ return nothing or in unit-test mock mode (capped to recent 20 lines).
1640
  """
1641
  # Try real SIEM query in non-mock modes
1642
  if self._docker_available is not False or self._execution_mode == "subprocess":
 
1644
  if siem_alerts:
1645
  return siem_alerts
1646
 
1647
+ # Fallback: synthesize alerts from recent Red actions so Blue still
1648
+ # receives actionable signal in mock/degraded SIEM paths.
1649
+ synthetic: list[str] = []
1650
+ for record in self._red_history:
1651
+ if record.get("type") in ("hallucinated_flag", "evidence"):
1652
+ continue
1653
+ command = str(record.get("command", "")).strip()
1654
+ if not command:
1655
+ continue
1656
+ step = record.get("step", "?")
1657
+ cmd_name = str(record.get("cmd_name", "")).strip() or _extract_command_name(command)
1658
+ target = str(record.get("target", "")).strip()
1659
+ if target:
1660
+ synthetic.append(f"[synthetic] step={step} cmd={cmd_name} target={target} :: {command}")
1661
+ else:
1662
+ synthetic.append(f"[synthetic] step={step} cmd={cmd_name} :: {command}")
1663
+ return synthetic[-20:]
1664
 
1665
  # -----------------------------------------------------------------
1666
  # Introspection (for reward computation and debugging)
tests/test_console_context.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for console environment context resolution."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from types import SimpleNamespace
6
+
7
+ from open_range.server.console import _get_env_context
8
+ from open_range.server.environment import RangeEnvironment
9
+
10
+
11
+ class _Req:
12
+ def __init__(self, app):
13
+ self.app = app
14
+
15
+
16
+ def _app_with_state(**kwargs):
17
+ return SimpleNamespace(state=SimpleNamespace(**kwargs))
18
+
19
+
20
+ def test_prefers_active_websocket_session_env():
21
+ fallback_env = RangeEnvironment(docker_available=False)
22
+ ws_env = RangeEnvironment(docker_available=False)
23
+ server = SimpleNamespace(
24
+ _sessions={"session_a": ws_env},
25
+ _session_info={"session_a": SimpleNamespace(last_activity_at=10.0)},
26
+ )
27
+ request = _Req(_app_with_state(env=fallback_env, openenv_server=server))
28
+
29
+ ctx = _get_env_context(request)
30
+ assert ctx["env"] is ws_env
31
+ assert ctx["state_scope"] == "websocket_session"
32
+ assert ctx["session_id"] == "session_a"
33
+ assert ctx["warning"] is None
34
+
35
+
36
+ def test_uses_app_state_env_when_no_active_session():
37
+ fallback_env = RangeEnvironment(docker_available=False)
38
+ server = SimpleNamespace(_sessions={}, _session_info={})
39
+ request = _Req(_app_with_state(env=fallback_env, openenv_server=server))
40
+
41
+ ctx = _get_env_context(request)
42
+ assert ctx["env"] is fallback_env
43
+ assert ctx["state_scope"] == "app_state_env"
44
+ assert ctx["session_id"] is None
45
+ assert isinstance(ctx["warning"], str) and ctx["warning"]
46
+
47
+
48
+ def test_multiple_sessions_selects_most_recent_and_warns():
49
+ older_env = RangeEnvironment(docker_available=False)
50
+ newer_env = RangeEnvironment(docker_available=False)
51
+ server = SimpleNamespace(
52
+ _sessions={"old": older_env, "new": newer_env},
53
+ _session_info={
54
+ "old": SimpleNamespace(last_activity_at=10.0),
55
+ "new": SimpleNamespace(last_activity_at=20.0),
56
+ },
57
+ )
58
+ request = _Req(_app_with_state(openenv_server=server))
59
+
60
+ ctx = _get_env_context(request)
61
+ assert ctx["env"] is newer_env
62
+ assert ctx["state_scope"] == "websocket_session"
63
+ assert ctx["session_id"] == "new"
64
+ assert "active sessions" in (ctx["warning"] or "").lower()
tests/test_environment.py CHANGED
@@ -66,6 +66,12 @@ class TestReset:
66
  assert isinstance(obs, RangeObservation)
67
  assert env.snapshot is not None
68
 
 
 
 
 
 
 
69
 
70
  class TestTargetResolution:
71
  """Target selection should honor manifest-compiled metadata."""
@@ -187,6 +193,14 @@ class TestBlueStep:
187
  obs = env.step(RangeAction(command="", mode="blue"))
188
  assert obs.stderr != ""
189
 
 
 
 
 
 
 
 
 
190
  def test_step_passes_timeout_override_to_executor(self):
191
  env = RangeEnvironment(docker_available=False)
192
  env.reset(snapshot=_MINIMAL_SNAPSHOT)
 
66
  assert isinstance(obs, RangeObservation)
67
  assert env.snapshot is not None
68
 
69
+ def test_reset_initializes_services_status_from_topology_hosts(self):
70
+ env = RangeEnvironment(docker_available=False)
71
+ env.reset(snapshot=_MINIMAL_SNAPSHOT)
72
+ # In mock mode service health is unknown, but hosts should be tracked.
73
+ assert set(env.state.services_status.keys()) == {"attacker", "siem"}
74
+
75
 
76
  class TestTargetResolution:
77
  """Target selection should honor manifest-compiled metadata."""
 
193
  obs = env.step(RangeAction(command="", mode="blue"))
194
  assert obs.stderr != ""
195
 
196
+ def test_blue_alerts_fall_back_to_synthetic_red_history(self):
197
+ env = RangeEnvironment(docker_available=False)
198
+ env.reset(snapshot=_MINIMAL_SNAPSHOT)
199
+ env.step(RangeAction(command="nmap -sV web", mode="red"))
200
+ obs = env.step(RangeAction(command="tail -n 50 /var/log/siem/all.log", mode="blue"))
201
+ assert obs.alerts
202
+ assert any("synthetic" in alert.lower() for alert in obs.alerts)
203
+
204
  def test_step_passes_timeout_override_to_executor(self):
205
  env = RangeEnvironment(docker_available=False)
206
  env.reset(snapshot=_MINIMAL_SNAPSHOT)