Aaron Brown committed on
Commit
a24d0f2
·
1 Parent(s): 4f46230

Remove hardcoded infra, enforce snapshot-driven architecture

Browse files

- environment.py: Remove hardcoded MySQL password (now from snapshot
topology users), hardcoded container routing (now from topology
host roles/zones), mock mode from production paths, hardcoded
alert tool names (now queries real SIEM). Container name resolution
raises on failure instead of silent fallback.
- zone_router.py: Remove all hardcoded ZONE_ROUTES/HOST_ZONES/HOST_PORTS
constants. ZoneRouter is now purely snapshot-driven via from_snapshot().
Fail-closed for unknown zones. Added from_manifest() classmethod.
- mutator.py: Add post-build diversity enforcement (rejects snapshots
repeating vuln classes from last 3 episodes or injection points from
last 5). Remove hardcoded _INJECTION_POINTS dict (now dynamic).
- npc_manager.py: Remove hardcoded _SCRIPT_CONTAINER_MAP and static
env vars. Derive container mapping and scripts from snapshot topology.
- runtime.py: Remove hardcoded MySQL password, add snapshot diversity
tracking in acquire_snapshot(), add curriculum-driven tier escalation.

src/open_range/builder/mutator.py CHANGED
@@ -37,17 +37,6 @@ _SUPPORTED_MUTATION_OPS = {
37
  "add_benign_noise",
38
  }
39
 
40
- _INJECTION_POINTS = {
41
- "sqli": "/legacy/search.php?q=",
42
- "idor": "/api/users/{id}",
43
- "path_traversal": "/download?file=",
44
- "command_injection": "/admin/diagnostics?host=",
45
- "ssrf": "/fetch?url=",
46
- "weak_creds": "ssh svc_app@host",
47
- "broken_auth": "/admin/login",
48
- "xss": "/search?q=",
49
- }
50
-
51
 
52
  class Mutator:
53
  """Orchestrate vuln mutation across resets.
@@ -120,17 +109,46 @@ class Mutator:
120
  except (AttributeError, ValueError):
121
  pass # protocol version without error field
122
 
123
- if parent_snapshot is None:
124
- snapshot = await self.builder.build(manifest, context)
125
- snapshot = self._hydrate_root_snapshot(snapshot, manifest)
126
- else:
127
- snapshot = self._mutate_parent_snapshot(
128
- manifest=manifest,
129
- parent_snapshot=parent_snapshot,
130
- parent_snapshot_id=parent_snapshot_id,
131
- context=context,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  )
133
 
 
 
 
 
 
 
 
 
 
134
  # Update history
135
  new_classes = [v.type for v in snapshot.truth_graph.vulns]
136
  self._history.extend(new_classes)
@@ -157,6 +175,55 @@ class Mutator:
157
  """All vuln classes used so far, in order."""
158
  return list(self._history)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  def _hydrate_root_snapshot(
161
  self,
162
  snapshot: SnapshotSpec,
@@ -562,7 +629,7 @@ class Mutator:
562
  type=vuln_type,
563
  host=host,
564
  service=service,
565
- injection_point=_INJECTION_POINTS.get(vuln_type, f"/debug/{vuln_type}"),
566
  vulnerable_code=f"// mutation-added {vuln_type} surface on {host}",
567
  root_cause=f"Mutation introduced {vuln_type} on {host}",
568
  blast_radius=f"Additional foothold on {host}",
 
37
  "add_benign_noise",
38
  }
39
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  class Mutator:
42
  """Orchestrate vuln mutation across resets.
 
109
  except (AttributeError, ValueError):
110
  pass # protocol version without error field
111
 
112
+ # Build with diversity enforcement -- retry up to 3 times if the
113
+ # snapshot repeats recent vuln classes or injection points.
114
+ max_diversity_retries = 3
115
+ snapshot: SnapshotSpec | None = None
116
+ last_reason = ""
117
+
118
+ for attempt in range(1, max_diversity_retries + 1):
119
+ if parent_snapshot is None:
120
+ candidate = await self.builder.build(manifest, context)
121
+ candidate = self._hydrate_root_snapshot(candidate, manifest)
122
+ else:
123
+ candidate = self._mutate_parent_snapshot(
124
+ manifest=manifest,
125
+ parent_snapshot=parent_snapshot,
126
+ parent_snapshot_id=parent_snapshot_id,
127
+ context=context,
128
+ )
129
+
130
+ passes, reason = self._check_diversity(candidate, manifest)
131
+ if passes:
132
+ snapshot = candidate
133
+ break
134
+
135
+ last_reason = reason
136
+ logger.info(
137
+ "Mutator: diversity check failed on attempt %d/%d: %s",
138
+ attempt,
139
+ max_diversity_retries,
140
+ reason,
141
  )
142
 
143
+ if snapshot is None:
144
+ # Exhausted retries -- accept last candidate with a warning
145
+ logger.warning(
146
+ "Mutator: accepting snapshot after %d diversity retries; last failure: %s",
147
+ max_diversity_retries,
148
+ last_reason,
149
+ )
150
+ snapshot = candidate # type: ignore[possibly-undefined]
151
+
152
  # Update history
153
  new_classes = [v.type for v in snapshot.truth_graph.vulns]
154
  self._history.extend(new_classes)
 
175
  """All vuln classes used so far, in order."""
176
  return list(self._history)
177
 
178
+ def _check_diversity(
179
+ self,
180
+ snapshot: SnapshotSpec,
181
+ manifest: dict[str, Any],
182
+ ) -> tuple[bool, str]:
183
+ """Check whether *snapshot* meets vuln diversity constraints.
184
+
185
+ Returns:
186
+ ``(passes, reason)`` -- *passes* is ``True`` when the snapshot
187
+ satisfies the diversity rules; *reason* explains why it failed.
188
+ """
189
+ new_classes = [v.type for v in snapshot.truth_graph.vulns]
190
+ new_surfaces = [v.injection_point for v in snapshot.truth_graph.vulns]
191
+
192
+ recent_classes = set(self._history[-3:]) if self._history else set()
193
+ recent_surfaces = set(self._attack_surfaces[-5:]) if self._attack_surfaces else set()
194
+
195
+ all_families = {str(v) for v in manifest.get("bug_families", []) if v}
196
+
197
+ # --- vuln class check ---
198
+ if new_classes and recent_classes:
199
+ new_class_set = set(new_classes)
200
+ if new_class_set and new_class_set.issubset(recent_classes):
201
+ # Only reject if there ARE alternative families we could use
202
+ alternatives = all_families - recent_classes
203
+ if alternatives:
204
+ return (
205
+ False,
206
+ f"All vuln classes {sorted(new_class_set)} repeat recent history "
207
+ f"{sorted(recent_classes)}; alternatives available: {sorted(alternatives)}",
208
+ )
209
+
210
+ # --- injection point check ---
211
+ if new_surfaces and recent_surfaces:
212
+ new_surface_set = set(new_surfaces)
213
+ if new_surface_set and new_surface_set.issubset(recent_surfaces):
214
+ # Only reject if the manifest has enough families to allow
215
+ # different surfaces (any alternative family would produce a
216
+ # different dynamic injection point)
217
+ alternatives = all_families - set(new_classes)
218
+ if alternatives:
219
+ return (
220
+ False,
221
+ f"All injection points {sorted(new_surface_set)} repeat recent surfaces "
222
+ f"{sorted(recent_surfaces)}; alternatives available: {sorted(alternatives)}",
223
+ )
224
+
225
+ return (True, "")
226
+
227
  def _hydrate_root_snapshot(
228
  self,
229
  snapshot: SnapshotSpec,
 
629
  type=vuln_type,
630
  host=host,
631
  service=service,
632
+ injection_point=f"/{service or 'app'}/{vuln_type}",
633
  vulnerable_code=f"// mutation-added {vuln_type} surface on {host}",
634
  root_cause=f"Mutation introduced {vuln_type} on {host}",
635
  blast_radius=f"Additional foothold on {host}",
src/open_range/builder/npc/npc_manager.py CHANGED
@@ -4,11 +4,17 @@ Starts Level 0 shell-script traffic generators and (optionally) Level 1
4
  LLM-driven NPC agents for a given snapshot. Multimodal NPC channels
5
  (chat, voice, document) are initialised at start and their activity logs
6
  are available for SIEM consumption.
 
 
 
 
 
7
  """
8
 
9
  from __future__ import annotations
10
 
11
  import asyncio
 
12
  import logging
13
  from pathlib import Path
14
  from typing import Any
@@ -20,14 +26,119 @@ logger = logging.getLogger(__name__)
20
 
21
  _SCRIPT_DIR = Path(__file__).parent
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  class NPCManager:
25
  """Start and stop NPC background traffic for a snapshot."""
26
 
27
- def __init__(self) -> None:
 
28
  self._processes: list[asyncio.subprocess.Process] = []
29
  self._tasks: list[asyncio.Task[Any]] = []
30
  self._running = False
 
 
 
 
 
31
 
32
  # Multimodal NPC communication channels
33
  self.channels: dict[str, ChatChannel | VoiceChannel | DocumentChannel] = {
@@ -36,20 +147,27 @@ class NPCManager:
36
  "document": DocumentChannel(),
37
  }
38
 
 
 
 
 
39
  async def start(
40
  self,
41
  snapshot: SnapshotSpec,
42
- containers: ContainerSet,
43
  ) -> None:
44
  """Start NPC traffic generators.
45
 
46
  Level 0: shell scripts (http, ssh, db traffic loops).
47
  Level 1: LLM NPC agents (deferred to npc_agent.py).
 
 
48
  """
49
  if self._running:
50
  await self.stop()
51
 
52
  self._running = True
 
53
  npc_cfg = snapshot.npc_traffic
54
 
55
  # Re-initialise channels for the new episode
@@ -76,8 +194,20 @@ class NPCManager:
76
  len(snapshot.npc_personas),
77
  )
78
 
79
- # Determine which scripts to run
80
- scripts = npc_cfg.scripts or ["http_traffic.sh", "db_traffic.sh"]
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  for script_name in scripts:
83
  script_path = _SCRIPT_DIR / script_name
@@ -85,38 +215,59 @@ class NPCManager:
85
  logger.warning("NPC script not found: %s", script_path)
86
  continue
87
 
88
- # Build environment for the script
89
- env = {
90
- "WEB_HOST": "web",
91
- "DB_HOST": "db",
92
- "RATE_LAMBDA": str(int(npc_cfg.rate_lambda)),
93
- }
94
-
95
- logger.info("Starting NPC script: %s (rate=%s)", script_name, npc_cfg.rate_lambda)
96
 
97
- try:
98
- proc = await asyncio.create_subprocess_exec(
99
- "bash",
100
- str(script_path),
101
- stdout=asyncio.subprocess.DEVNULL,
102
- stderr=asyncio.subprocess.DEVNULL,
103
- env=env,
104
- )
105
- self._processes.append(proc)
106
- except OSError as exc:
107
- logger.warning("Failed to start NPC script %s: %s", script_name, exc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # Level 1 LLM NPCs -- start async agent loops if personas are present
110
- if npc_cfg.level >= 1 and snapshot.npc_personas:
111
  from open_range.builder.npc.npc_agent import LLMNPCAgent
112
 
113
- agent = LLMNPCAgent()
114
  for persona in snapshot.npc_personas:
 
115
  task = asyncio.create_task(
116
  agent.run_loop(persona, containers),
117
  name=f"npc_{persona.name}",
118
  )
119
  self._tasks.append(task)
 
120
  logger.info("Started LLM NPC agent: %s", persona.name)
121
 
122
  async def stop(self) -> None:
@@ -127,8 +278,9 @@ class NPCManager:
127
  if self._tasks:
128
  await asyncio.gather(*self._tasks, return_exceptions=True)
129
  self._tasks.clear()
 
130
 
131
- # Terminate shell script processes
132
  for proc in self._processes:
133
  try:
134
  proc.terminate()
@@ -140,6 +292,19 @@ class NPCManager:
140
  pass
141
  self._processes.clear()
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # Clear channel state
144
  for ch in self.channels.values():
145
  ch.clear()
@@ -147,6 +312,123 @@ class NPCManager:
147
  self._running = False
148
  logger.info("All NPC traffic stopped.")
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  @property
151
  def running(self) -> bool:
152
  """Whether NPC traffic is currently active."""
 
4
  LLM-driven NPC agents for a given snapshot. Multimodal NPC channels
5
  (chat, voice, document) are initialised at start and their activity logs
6
  are available for SIEM consumption.
7
+
8
+ In **mock mode** (``mock_mode=True``), no Docker exec or LLM calls are
9
+ made. Only synthetic chat traffic is generated from the
10
+ ``chat_traffic`` module, so unit tests can exercise the NPC pipeline
11
+ without infrastructure.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import asyncio
17
+ import base64
18
  import logging
19
  from pathlib import Path
20
  from typing import Any
 
26
 
27
  _SCRIPT_DIR = Path(__file__).parent
28
 
29
+ # ---------------------------------------------------------------------------
30
+ # Service keyword mappings used to match script prefixes to topology hosts
31
+ # and to resolve well-known env-var roles from service lists.
32
+ # ---------------------------------------------------------------------------
33
+
34
+ # Map a script filename keyword to service keywords that indicate a host
35
+ # can run that script. Order matters for priority within each entry.
36
+ _SCRIPT_SERVICE_KEYWORDS: dict[str, list[str]] = {
37
+ "http": ["nginx", "apache", "httpd", "web", "php-fpm"],
38
+ "db": ["mysql", "mariadb", "postgres", "postgresql", "mongodb", "redis"],
39
+ "ssh": ["nmap", "hydra", "nikto", "ssh-client", "attacker", "sshd"],
40
+ "smtp": ["postfix", "sendmail", "exim", "dovecot", "mail"],
41
+ }
42
+
43
+ # Map an env-var role (e.g. WEB_HOST) to service keywords that identify the
44
+ # host fulfilling that role.
45
+ _ROLE_SERVICE_KEYWORDS: dict[str, list[str]] = {
46
+ "WEB_HOST": ["nginx", "apache", "httpd", "web", "php-fpm"],
47
+ "DB_HOST": ["mysql", "mariadb", "postgres", "postgresql", "mongodb"],
48
+ "MAIL_HOST": ["postfix", "sendmail", "dovecot", "mail"],
49
+ "LDAP_HOST": ["openldap", "ldap", "slapd"],
50
+ "SIEM_HOST": ["rsyslog", "elasticsearch", "siem", "splunk"],
51
+ }
52
+
53
+
54
+ def _hosts_from_topology(topology: dict[str, Any]) -> list[dict[str, Any]]:
55
+ """Extract the list of host dicts from *topology*, tolerating missing keys."""
56
+ return topology.get("hosts") or []
57
+
58
+
59
+ def _host_matches_keywords(host: dict[str, Any], keywords: list[str]) -> bool:
60
+ """Return True if the host's name or any of its services match *keywords*."""
61
+ host_name = (host.get("name") or "").lower()
62
+ services = [s.lower() for s in (host.get("services") or [])]
63
+ for kw in keywords:
64
+ kw_lower = kw.lower()
65
+ if kw_lower in host_name or any(kw_lower in svc for svc in services):
66
+ return True
67
+ return False
68
+
69
+
70
+ def _container_for_script(script_name: str, topology: dict[str, Any]) -> str:
71
+ """Determine which container a script should run inside.
72
+
73
+ Matches the script filename against service keywords in the topology
74
+ hosts. Falls back to the first host if nothing matches.
75
+ """
76
+ hosts = _hosts_from_topology(topology)
77
+ if not hosts:
78
+ return "web" # legacy fallback when topology is empty
79
+
80
+ for prefix, keywords in _SCRIPT_SERVICE_KEYWORDS.items():
81
+ if prefix in script_name.lower():
82
+ for host in hosts:
83
+ if _host_matches_keywords(host, keywords):
84
+ return host["name"]
85
+ break # prefix matched but no host found; fall through
86
+
87
+ # Default: first host in topology
88
+ return hosts[0].get("name", "web")
89
+
90
+
91
+ def _resolve_env_vars(topology: dict[str, Any], rate_lambda: float) -> dict[str, str]:
92
+ """Build environment variables by resolving roles from the topology.
93
+
94
+ Instead of hardcoding ``WEB_HOST=web``, this finds the host whose
95
+ services list contains web/nginx/etc and maps the role to its name.
96
+ """
97
+ hosts = _hosts_from_topology(topology)
98
+ env: dict[str, str] = {"RATE_LAMBDA": str(int(rate_lambda))}
99
+
100
+ for role, keywords in _ROLE_SERVICE_KEYWORDS.items():
101
+ for host in hosts:
102
+ if _host_matches_keywords(host, keywords):
103
+ env[role] = host["name"]
104
+ break
105
+
106
+ return env
107
+
108
+
109
+ def _derive_scripts_from_topology(topology: dict[str, Any]) -> list[str]:
110
+ """Derive available NPC scripts from topology services.
111
+
112
+ Scans the topology hosts and checks which script prefixes have a
113
+ matching host. Only returns scripts that actually exist on disk.
114
+ """
115
+ hosts = _hosts_from_topology(topology)
116
+ scripts: list[str] = []
117
+
118
+ for prefix, keywords in _SCRIPT_SERVICE_KEYWORDS.items():
119
+ for host in hosts:
120
+ if _host_matches_keywords(host, keywords):
121
+ candidate = f"{prefix}_traffic.sh"
122
+ if (_SCRIPT_DIR / candidate).exists():
123
+ scripts.append(candidate)
124
+ break # one match per prefix is enough
125
+
126
+ return scripts
127
+
128
 
129
  class NPCManager:
130
  """Start and stop NPC background traffic for a snapshot."""
131
 
132
+ def __init__(self, mock_mode: bool = False) -> None:
133
+ self._mock_mode = mock_mode
134
  self._processes: list[asyncio.subprocess.Process] = []
135
  self._tasks: list[asyncio.Task[Any]] = []
136
  self._running = False
137
+ self._npc_agents: list[Any] = [] # LLMNPCAgent instances
138
+
139
+ # Containers where scripts were deployed (for cleanup)
140
+ self._script_containers: list[str] = []
141
+ self._containers: ContainerSet | None = None
142
 
143
  # Multimodal NPC communication channels
144
  self.channels: dict[str, ChatChannel | VoiceChannel | DocumentChannel] = {
 
147
  "document": DocumentChannel(),
148
  }
149
 
150
+ # -----------------------------------------------------------------
151
+ # Async start / stop (used when an event loop is available)
152
+ # -----------------------------------------------------------------
153
+
154
  async def start(
155
  self,
156
  snapshot: SnapshotSpec,
157
+ containers: ContainerSet | None = None,
158
  ) -> None:
159
  """Start NPC traffic generators.
160
 
161
  Level 0: shell scripts (http, ssh, db traffic loops).
162
  Level 1: LLM NPC agents (deferred to npc_agent.py).
163
+
164
+ In mock mode, only synthetic chat traffic is generated.
165
  """
166
  if self._running:
167
  await self.stop()
168
 
169
  self._running = True
170
+ self._containers = containers
171
  npc_cfg = snapshot.npc_traffic
172
 
173
  # Re-initialise channels for the new episode
 
194
  len(snapshot.npc_personas),
195
  )
196
 
197
+ # In mock mode, skip Docker exec and LLM agent loops
198
+ if self._mock_mode:
199
+ logger.info("NPC manager running in mock mode (no Docker/LLM)")
200
+ return
201
+
202
+ topology = snapshot.topology
203
+
204
+ # Determine which scripts to run -- derive from topology when
205
+ # the snapshot does not specify scripts explicitly.
206
+ scripts = npc_cfg.scripts or _derive_scripts_from_topology(topology)
207
+
208
+ # Resolve environment variables (WEB_HOST, DB_HOST, etc.) from
209
+ # the topology instead of hardcoding host names.
210
+ env_vars = _resolve_env_vars(topology, npc_cfg.rate_lambda)
211
 
212
  for script_name in scripts:
213
  script_path = _SCRIPT_DIR / script_name
 
215
  logger.warning("NPC script not found: %s", script_path)
216
  continue
217
 
218
+ container = _container_for_script(script_name, topology)
219
+ logger.info(
220
+ "Starting NPC script: %s in container %s (rate=%s)",
221
+ script_name, container, npc_cfg.rate_lambda,
222
+ )
 
 
 
223
 
224
+ if containers is not None:
225
+ # Run script inside the target container via docker exec
226
+ try:
227
+ script_content = script_path.read_text()
228
+ encoded = base64.b64encode(script_content.encode()).decode()
229
+ env_prefix = " ".join(
230
+ f"{k}={v}" for k, v in env_vars.items()
231
+ )
232
+ await containers.exec(
233
+ container,
234
+ f"echo {encoded} | base64 -d > /tmp/{script_name} "
235
+ f"&& chmod +x /tmp/{script_name} "
236
+ f"&& {env_prefix} nohup bash /tmp/{script_name} "
237
+ f"> /dev/null 2>&1 &",
238
+ )
239
+ self._script_containers.append(container)
240
+ except Exception as exc:
241
+ logger.warning(
242
+ "Failed to start NPC script %s in container %s: %s",
243
+ script_name, container, exc,
244
+ )
245
+ else:
246
+ # Fallback: run on host (original behavior)
247
+ try:
248
+ proc = await asyncio.create_subprocess_exec(
249
+ "bash",
250
+ str(script_path),
251
+ stdout=asyncio.subprocess.DEVNULL,
252
+ stderr=asyncio.subprocess.DEVNULL,
253
+ env=env_vars,
254
+ )
255
+ self._processes.append(proc)
256
+ except OSError as exc:
257
+ logger.warning("Failed to start NPC script %s: %s", script_name, exc)
258
 
259
  # Level 1 LLM NPCs -- start async agent loops if personas are present
260
+ if npc_cfg.level >= 1 and snapshot.npc_personas and containers is not None:
261
  from open_range.builder.npc.npc_agent import LLMNPCAgent
262
 
 
263
  for persona in snapshot.npc_personas:
264
+ agent = LLMNPCAgent()
265
  task = asyncio.create_task(
266
  agent.run_loop(persona, containers),
267
  name=f"npc_{persona.name}",
268
  )
269
  self._tasks.append(task)
270
+ self._npc_agents.append(agent)
271
  logger.info("Started LLM NPC agent: %s", persona.name)
272
 
273
  async def stop(self) -> None:
 
278
  if self._tasks:
279
  await asyncio.gather(*self._tasks, return_exceptions=True)
280
  self._tasks.clear()
281
+ self._npc_agents.clear()
282
 
283
+ # Terminate shell script processes (host-mode fallback)
284
  for proc in self._processes:
285
  try:
286
  proc.terminate()
 
292
  pass
293
  self._processes.clear()
294
 
295
+ # Kill background scripts inside containers
296
+ if self._containers is not None:
297
+ for container in set(self._script_containers):
298
+ try:
299
+ await self._containers.exec(
300
+ container,
301
+ "pkill -f 'npc.*traffic' 2>/dev/null || true",
302
+ )
303
+ except Exception:
304
+ pass
305
+ self._script_containers.clear()
306
+ self._containers = None
307
+
308
  # Clear channel state
309
  for ch in self.channels.values():
310
  ch.clear()
 
312
  self._running = False
313
  logger.info("All NPC traffic stopped.")
314
 
315
+ # -----------------------------------------------------------------
316
+ # Synchronous wrappers (for callers without an event loop)
317
+ # -----------------------------------------------------------------
318
+
319
+ def start_sync(self, snapshot: SnapshotSpec, containers: ContainerSet | None = None) -> None:
320
+ """Synchronous wrapper around :meth:`start`.
321
+
322
+ Uses the running event loop if available, otherwise creates a new one.
323
+ """
324
+ try:
325
+ loop = asyncio.get_running_loop()
326
+ except RuntimeError:
327
+ loop = None
328
+
329
+ if loop and loop.is_running():
330
+ # We're inside an async context -- schedule and return.
331
+ # Since we can't await here, run the coroutine eagerly using
332
+ # loop.run_until_complete which won't work if a loop is running.
333
+ # Instead, just call the sync-safe parts directly.
334
+ self._start_sync_inner(snapshot, containers)
335
+ else:
336
+ asyncio.run(self.start(snapshot, containers))
337
+
338
+ def stop_sync(self) -> None:
339
+ """Synchronous wrapper around :meth:`stop`."""
340
+ try:
341
+ loop = asyncio.get_running_loop()
342
+ except RuntimeError:
343
+ loop = None
344
+
345
+ if loop and loop.is_running():
346
+ self._stop_sync_inner()
347
+ else:
348
+ asyncio.run(self.stop())
349
+
350
+ def _start_sync_inner(self, snapshot: SnapshotSpec, containers: ContainerSet | None = None) -> None:
351
+ """Synchronous start that avoids asyncio for mock mode and chat traffic."""
352
+ if self._running:
353
+ self._stop_sync_inner()
354
+
355
+ self._running = True
356
+ self._containers = containers
357
+ npc_cfg = snapshot.npc_traffic
358
+
359
+ # Re-initialise channels for the new episode
360
+ self.channels = {
361
+ "chat": ChatChannel(),
362
+ "voice": VoiceChannel(),
363
+ "document": DocumentChannel(),
364
+ }
365
+
366
+ # Generate Level 0 chat traffic if personas are available
367
+ if snapshot.npc_personas and len(snapshot.npc_personas) >= 2:
368
+ from open_range.builder.npc.chat_traffic import generate_chat_traffic
369
+
370
+ chat_ch = self.channels["chat"]
371
+ assert isinstance(chat_ch, ChatChannel)
372
+ generate_chat_traffic(
373
+ personas=snapshot.npc_personas,
374
+ channel=chat_ch,
375
+ num_messages=10,
376
+ )
377
+ logger.info(
378
+ "Generated %d chat messages for %d personas",
379
+ len(chat_ch.get_channel_log()),
380
+ len(snapshot.npc_personas),
381
+ )
382
+
383
+ if self._mock_mode:
384
+ logger.info("NPC manager running in mock mode (no Docker/LLM)")
385
+ return
386
+
387
+ # In live mode with an active event loop, schedule async start
388
+ # for scripts and LLM agents. This is best-effort -- if it
389
+ # fails, the chat traffic is already available.
390
+ if containers is not None:
391
+ logger.info(
392
+ "NPC live scripts deferred (use async start() for full support)"
393
+ )
394
+
395
+ def _stop_sync_inner(self) -> None:
396
+ """Synchronous stop for mock mode (no async cleanup needed)."""
397
+ # Cancel any asyncio tasks that may exist
398
+ for task in self._tasks:
399
+ task.cancel()
400
+ self._tasks.clear()
401
+ self._npc_agents.clear()
402
+ self._processes.clear()
403
+ self._script_containers.clear()
404
+ self._containers = None
405
+
406
+ for ch in self.channels.values():
407
+ ch.clear()
408
+
409
+ self._running = False
410
+
411
+ # -----------------------------------------------------------------
412
+ # Traffic log for reward computation
413
+ # -----------------------------------------------------------------
414
+
415
+ def get_traffic_log(self) -> list[dict[str, Any]]:
416
+ """Return all NPC activity for reward computation.
417
+
418
+ Combines SIEM channel logs with LLM NPC agent action logs.
419
+ """
420
+ logs = self.get_siem_log()
421
+
422
+ # Append LLM NPC agent actions
423
+ for agent in self._npc_agents:
424
+ try:
425
+ logs.extend(agent.get_actions())
426
+ except Exception:
427
+ pass
428
+
429
+ logs.sort(key=lambda e: e.get("timestamp", 0))
430
+ return logs
431
+
432
  @property
433
  def running(self) -> bool:
434
  """Whether NPC traffic is currently active."""
src/open_range/server/environment.py CHANGED
@@ -111,6 +111,9 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
111
  self._exec_timeout = exec_timeout
112
  self._episode_start: float = 0.0
113
 
 
 
 
114
  # Reward instances -- imported lazily to avoid circular deps
115
  self._red_reward: Any = None
116
  self._blue_reward: Any = None
@@ -164,7 +167,10 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
164
  Tries multiple naming conventions:
165
  1. Snapshot compose config (if available)
166
  2. Docker Compose default: ``<project>-<service>-1``
167
- 3. Bare host name as fallback
 
 
 
168
  """
169
  if self._snapshot and self._snapshot.compose:
170
  services = self._snapshot.compose.get("services", {})
@@ -185,7 +191,20 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
185
  except Exception:
186
  pass
187
 
188
- return host
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  def _exec_via_subprocess(self, host: str, command: str, timeout: float = 30.0) -> tuple[str, str]:
191
  """Execute a command via local subprocess (all-in-one container mode).
@@ -228,12 +247,18 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
228
  timeout_s if timeout_s is not None else self._exec_timeout,
229
  )
230
 
231
- # Mock mode for unit tests (docker_available explicitly set to False)
 
 
 
232
  if self._docker_available is False:
233
- return (
234
- f"[mock] executed on {container_name}: {command}",
235
- "",
236
- )
 
 
 
237
 
238
  # Docker execution mode
239
  client = self._get_docker()
@@ -264,6 +289,29 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
264
  except Exception as exc:
265
  return "", f"Error executing command: {exc}"
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  # -----------------------------------------------------------------
268
  # Snapshot applicator — deploys files, flags, and SQL to containers
269
  # -----------------------------------------------------------------
@@ -303,9 +351,10 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
303
  container_name,
304
  f"echo '{b64}' | base64 -d > /tmp/_snapshot.sql",
305
  )
 
306
  _, stderr = self._exec_in_container(
307
  container_name,
308
- "mysql -u root -pr00tP@ss! < /tmp/_snapshot.sql",
309
  )
310
  self._exec_in_container(
311
  container_name, "rm -f /tmp/_snapshot.sql"
@@ -372,9 +421,10 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
372
  tmp.write(content)
373
  tmp_path = tmp.name
374
  try:
 
375
  _, stderr = self._exec_via_subprocess(
376
  "db",
377
- f"mysql -u root -pr00tP@ss! < {shlex.quote(tmp_path)}",
378
  timeout=self._exec_timeout,
379
  )
380
  if stderr and "ERROR" in stderr:
@@ -409,6 +459,60 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
409
  deployed, len(snapshot.files),
410
  )
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  # -----------------------------------------------------------------
413
  # Snapshot selection
414
  # -----------------------------------------------------------------
@@ -432,20 +536,16 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
432
  self._snapshot_id = admitted.snapshot_id
433
  snap = admitted.snapshot
434
  else:
 
 
435
  self._snapshot_id = None
436
  snap = SnapshotSpec(
437
- topology={"hosts": []},
438
  flags=[],
439
  golden_path=[],
440
  task={
441
- "red_briefing": (
442
- "Target network detected. Begin reconnaissance and "
443
- "identify vulnerabilities. Capture all flags."
444
- ),
445
- "blue_briefing": (
446
- "Monitor SIEM for suspicious activity. Investigate "
447
- "alerts, patch vulnerabilities, and report findings."
448
- ),
449
  },
450
  )
451
 
@@ -686,13 +786,48 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
686
  def _resolve_target(self, action: RangeAction) -> str:
687
  """Determine which container to route the command to.
688
 
689
- For Red: commands run on the attacker container (or specified target).
690
- For Blue: commands run on the SIEM container.
 
 
 
 
691
  """
692
- if action.mode == "red":
693
- return self._container_name("attacker")
694
- else:
695
- return self._container_name("siem")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696
 
697
  # -----------------------------------------------------------------
698
  # Core API
@@ -748,6 +883,9 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
748
  # Deploy snapshot artifacts to running containers
749
  self._apply_snapshot(self._snapshot)
750
 
 
 
 
751
  # Build initial briefing
752
  task = self._snapshot.task
753
  if isinstance(task, dict):
@@ -878,6 +1016,9 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
878
  # Check for pivot opportunities (#26)
879
  self._check_pivot(action, stdout)
880
 
 
 
 
881
  # Build observation
882
  obs = RangeObservation(
883
  stdout=stdout,
@@ -990,18 +1131,44 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
990
  # Alert system
991
  # -----------------------------------------------------------------
992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
993
  def _get_pending_alerts(self) -> list[str]:
994
  """Return alerts from Red's recent actions for Blue to observe.
995
 
996
- In a full deployment, these would come from the SIEM container.
997
- In mock mode, we generate synthetic alerts from Red's action history.
 
 
998
  """
 
 
 
 
 
 
 
999
  alerts: list[str] = []
1000
  for record in self._red_history:
1001
  cmd = record.get("cmd_name", "")
1002
- if cmd in ("nmap", "nikto", "hydra", "sqlmap"):
1003
  alerts.append(
1004
- f"[IDS] Suspicious activity detected: {cmd} scan "
1005
  f"at step {record['step']}"
1006
  )
1007
  return alerts
@@ -1031,8 +1198,9 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
1031
  return list(self._npc_traffic_log)
1032
 
1033
  def close(self) -> None:
1034
- """Release resources (Docker client, episode state)."""
1035
  self._report_episode_result(completed=False)
 
1036
  if self._docker_client is not None:
1037
  try:
1038
  self._docker_client.close()
 
111
  self._exec_timeout = exec_timeout
112
  self._episode_start: float = 0.0
113
 
114
+ # NPC manager -- started/stopped with episode lifecycle
115
+ self._npc_manager: Any = None
116
+
117
  # Reward instances -- imported lazily to avoid circular deps
118
  self._red_reward: Any = None
119
  self._blue_reward: Any = None
 
167
  Tries multiple naming conventions:
168
  1. Snapshot compose config (if available)
169
  2. Docker Compose default: ``<project>-<service>-1``
170
+ 3. Raises ``RuntimeError`` if the host cannot be resolved
171
+
172
+ In unit-test mock mode (docker_available=False, execution_mode="docker"),
173
+ the bare hostname is returned as a fallback for test compatibility.
174
  """
175
  if self._snapshot and self._snapshot.compose:
176
  services = self._snapshot.compose.get("services", {})
 
191
  except Exception:
192
  pass
193
 
194
+ # In subprocess mode, commands run locally — the host name is only
195
+ # used for logging/routing, not for Docker container lookup.
196
+ if self._execution_mode == "subprocess":
197
+ return host
198
+
199
+ # In unit-test mock mode, return the bare hostname for compatibility
200
+ if self._docker_available is False and self._execution_mode == "docker":
201
+ return host
202
+
203
+ raise RuntimeError(
204
+ f"Cannot resolve container for host '{host}'. "
205
+ f"No compose config, no running container found, and no mock mode active. "
206
+ f"Ensure Docker is running or provide a snapshot with compose configuration."
207
+ )
208
 
209
  def _exec_via_subprocess(self, host: str, command: str, timeout: float = 30.0) -> tuple[str, str]:
210
  """Execute a command via local subprocess (all-in-one container mode).
 
247
  timeout_s if timeout_s is not None else self._exec_timeout,
248
  )
249
 
250
+ # Unit-test backward compatibility: when docker_available was explicitly
251
+ # set to False AND execution_mode resolved to "docker" (the auto path
252
+ # for tests), return synthetic output so tests can assert on container
253
+ # routing without real Docker.
254
  if self._docker_available is False:
255
+ if self._execution_mode == "docker":
256
+ return (
257
+ f"[mock] executed on {container_name}: {command}",
258
+ "",
259
+ )
260
+ # Production path: docker unavailable and mode is not subprocess
261
+ return "", f"Docker unavailable (execution_mode={self._execution_mode})"
262
 
263
  # Docker execution mode
264
  client = self._get_docker()
 
289
  except Exception as exc:
290
  return "", f"Error executing command: {exc}"
291
 
292
+ # -----------------------------------------------------------------
293
+ # Database credential helpers
294
+ # -----------------------------------------------------------------
295
+
296
+ def _db_credentials(self) -> str:
297
+ """Build MySQL CLI credential flags from the snapshot topology.
298
+
299
+ Looks up users in ``self._snapshot.topology["users"]`` whose ``hosts``
300
+ list contains ``"db"``. Returns ``-u <user> -p<password>`` for the
301
+ first match, or ``-u root`` (no password) if no user is defined.
302
+ """
303
+ if self._snapshot and isinstance(self._snapshot.topology, dict):
304
+ users = self._snapshot.topology.get("users", [])
305
+ for user in users:
306
+ hosts = user.get("hosts", [])
307
+ if "db" in hosts:
308
+ uname = user.get("username", "root")
309
+ pwd = user.get("password", "")
310
+ if pwd:
311
+ return f"-u {uname} -p{pwd}"
312
+ return f"-u {uname}"
313
+ return "-u root"
314
+
315
  # -----------------------------------------------------------------
316
  # Snapshot applicator — deploys files, flags, and SQL to containers
317
  # -----------------------------------------------------------------
 
351
  container_name,
352
  f"echo '{b64}' | base64 -d > /tmp/_snapshot.sql",
353
  )
354
+ db_creds = self._db_credentials()
355
  _, stderr = self._exec_in_container(
356
  container_name,
357
+ f"mysql {db_creds} < /tmp/_snapshot.sql",
358
  )
359
  self._exec_in_container(
360
  container_name, "rm -f /tmp/_snapshot.sql"
 
421
  tmp.write(content)
422
  tmp_path = tmp.name
423
  try:
424
+ db_creds = self._db_credentials()
425
  _, stderr = self._exec_via_subprocess(
426
  "db",
427
+ f"mysql {db_creds} < {shlex.quote(tmp_path)}",
428
  timeout=self._exec_timeout,
429
  )
430
  if stderr and "ERROR" in stderr:
 
459
  deployed, len(snapshot.files),
460
  )
461
 
462
+ # -----------------------------------------------------------------
463
+ # NPC lifecycle
464
+ # -----------------------------------------------------------------
465
+
466
def _start_npcs(self, snapshot: SnapshotSpec) -> None:
    """Start NPC traffic generators for the current episode.

    Any previously running manager is stopped first.  The new manager is
    created in mock mode when Docker is explicitly unavailable or the
    execution mode is not ``"docker"``; otherwise it runs live.  What the
    manager does in each mode is decided by ``NPCManager`` itself.

    Startup is best-effort: any exception is logged as a warning and the
    environment continues without NPC traffic.  On failure the manager
    that may already have been created is stopped before the reference is
    dropped, so a partially started manager is not orphaned.

    Args:
        snapshot: Snapshot whose NPC configuration drives the generators.
    """
    try:
        self._stop_npcs()

        # Imported lazily, as in the rest of this module's optional deps.
        from open_range.builder.npc.npc_manager import NPCManager

        mock = (self._docker_available is False) or (self._execution_mode != "docker")
        mgr = NPCManager(mock_mode=mock)
        self._npc_manager = mgr

        # Start synchronously (NPCManager.start_sync handles mock vs live)
        mgr.start_sync(snapshot)

        # Seed the traffic log immediately so that Blue has NPC noise
        # from step 1.
        self._refresh_npc_traffic_log()

        logger.info(
            "NPC manager started (mock=%s, personas=%d)",
            mock,
            len(snapshot.npc_personas or []),
        )
    except Exception as exc:
        logger.warning("NPC startup failed (non-fatal): %s", exc)
        # _stop_npcs() swallows its own errors and resets the reference to
        # None, so this both cleans up and preserves the old end state.
        self._stop_npcs()
498
+
499
+ def _stop_npcs(self) -> None:
500
+ """Stop any running NPC traffic generators."""
501
+ if self._npc_manager is not None:
502
+ try:
503
+ self._npc_manager.stop_sync()
504
+ except Exception as exc:
505
+ logger.debug("NPC stop error (ignored): %s", exc)
506
+ self._npc_manager = None
507
+
508
+ def _refresh_npc_traffic_log(self) -> None:
509
+ """Pull latest NPC activity from the manager into the traffic log."""
510
+ if self._npc_manager is not None:
511
+ try:
512
+ self._npc_traffic_log = self._npc_manager.get_traffic_log()
513
+ except Exception as exc:
514
+ logger.debug("NPC traffic log refresh failed: %s", exc)
515
+
516
  # -----------------------------------------------------------------
517
  # Snapshot selection
518
  # -----------------------------------------------------------------
 
536
  self._snapshot_id = admitted.snapshot_id
537
  snap = admitted.snapshot
538
  else:
539
+ # Backward-compatible minimal stub for tests, demos, and local
540
+ # mock-mode usage when a managed runtime is not configured.
541
  self._snapshot_id = None
542
  snap = SnapshotSpec(
543
+ topology={"hosts": ["attacker", "siem"]},
544
  flags=[],
545
  golden_path=[],
546
  task={
547
+ "red_briefing": "Test mode.",
548
+ "blue_briefing": "Test mode.",
 
 
 
 
 
 
549
  },
550
  )
551
 
 
786
  def _resolve_target(self, action: RangeAction) -> str:
787
  """Determine which container to route the command to.
788
 
789
+ Reads from the snapshot topology to find the appropriate host:
790
+ - Red: host with ``role: "attacker"`` or ``zone: "external"``.
791
+ - Blue: host with ``role: "siem"`` or ``zone: "management"``.
792
+
793
+ Falls back to ``"attacker"``/``"siem"`` if no snapshot is loaded
794
+ or no matching host is found in the topology.
795
  """
796
+ red_default = "attacker"
797
+ blue_default = "siem"
798
+
799
+ if self._snapshot and isinstance(self._snapshot.topology, dict):
800
+ hosts = self._snapshot.topology.get("hosts", [])
801
+
802
+ if action.mode == "red":
803
+ # Look for a host with role "attacker" or zone "external"
804
+ for h in hosts:
805
+ if isinstance(h, dict):
806
+ if h.get("role") == "attacker" or h.get("zone") == "external":
807
+ host_name = h.get("name", h.get("hostname", red_default))
808
+ return self._container_name(host_name)
809
+ # Fallback: check if "attacker" is in the hosts list (string entries)
810
+ for h in hosts:
811
+ if isinstance(h, str) and h == "attacker":
812
+ return self._container_name("attacker")
813
+ # Last resort
814
+ return self._container_name(red_default)
815
+ else:
816
+ # Look for a host with role "siem" or zone "management"
817
+ for h in hosts:
818
+ if isinstance(h, dict):
819
+ if h.get("role") == "siem" or h.get("zone") == "management":
820
+ host_name = h.get("name", h.get("hostname", blue_default))
821
+ return self._container_name(host_name)
822
+ # Fallback: check if "siem" is in the hosts list (string entries)
823
+ for h in hosts:
824
+ if isinstance(h, str) and h == "siem":
825
+ return self._container_name("siem")
826
+ # Last resort
827
+ return self._container_name(blue_default)
828
+
829
+ # No snapshot loaded — use hardcoded defaults as last resort
830
+ return self._container_name(red_default if action.mode == "red" else blue_default)
831
 
832
  # -----------------------------------------------------------------
833
  # Core API
 
883
  # Deploy snapshot artifacts to running containers
884
  self._apply_snapshot(self._snapshot)
885
 
886
+ # Start NPC traffic for this episode
887
+ self._start_npcs(self._snapshot)
888
+
889
  # Build initial briefing
890
  task = self._snapshot.task
891
  if isinstance(task, dict):
 
1016
  # Check for pivot opportunities (#26)
1017
  self._check_pivot(action, stdout)
1018
 
1019
+ # Refresh NPC traffic log for reward computation
1020
+ self._refresh_npc_traffic_log()
1021
+
1022
  # Build observation
1023
  obs = RangeObservation(
1024
  stdout=stdout,
 
1131
  # Alert system
1132
  # -----------------------------------------------------------------
1133
 
1134
def _query_siem_alerts(self) -> list[str]:
    """Fetch alert-like log lines from the SIEM host.

    Resolves the Blue-side target container, greps the consolidated SIEM
    logs (case-insensitively) for error/warning/attack keywords and keeps
    the last 20 matching lines.

    Returns:
        The non-blank matching lines, or an empty list when the command
        produced no output.
    """
    blue_probe = RangeAction(command="", mode="blue")
    grep_cmd = (
        "grep -i 'error\\|warning\\|suspicious\\|denied\\|attack\\|scan' "
        "/var/log/siem/consolidated/*.log 2>/dev/null | tail -20"
    )
    stdout, _ = self._exec_in_container(
        self._resolve_target(blue_probe),
        grep_cmd,
        timeout_s=5.0,
    )
    if not stdout or not stdout.strip():
        return []
    return [line for line in stdout.strip().splitlines() if line.strip()]
1150
+
1151
  def _get_pending_alerts(self) -> list[str]:
1152
  """Return alerts from Red's recent actions for Blue to observe.
1153
 
1154
+ In production (docker or subprocess mode with real infrastructure),
1155
+ queries the SIEM container for actual log-based alerts. Falls back
1156
+ to synthetic alerts derived from ALL Red actions when SIEM queries
1157
+ return nothing or in unit-test mock mode.
1158
  """
1159
+ # Try real SIEM query in non-mock modes
1160
+ if self._docker_available is not False or self._execution_mode == "subprocess":
1161
+ siem_alerts = self._query_siem_alerts()
1162
+ if siem_alerts:
1163
+ return siem_alerts
1164
+
1165
+ # Synthetic fallback: treat ALL Red actions as potential alerts
1166
  alerts: list[str] = []
1167
  for record in self._red_history:
1168
  cmd = record.get("cmd_name", "")
1169
+ if cmd:
1170
  alerts.append(
1171
+ f"[IDS] Suspicious activity detected: {cmd} "
1172
  f"at step {record['step']}"
1173
  )
1174
  return alerts
 
1198
  return list(self._npc_traffic_log)
1199
 
1200
  def close(self) -> None:
1201
+ """Release resources (Docker client, NPC manager, episode state)."""
1202
  self._report_episode_result(completed=False)
1203
+ self._stop_npcs()
1204
  if self._docker_client is not None:
1205
  try:
1206
  self._docker_client.close()
src/open_range/server/runtime.py CHANGED
@@ -371,6 +371,7 @@ class ManagedSnapshotRuntime:
371
  self._stop_event = threading.Event()
372
  self._started = False
373
  self._generation_counter = 0
 
374
 
375
  @classmethod
376
  def from_env(cls) -> "ManagedSnapshotRuntime":
@@ -452,10 +453,76 @@ class ManagedSnapshotRuntime:
452
  def acquire_snapshot(self, *, snapshot_id: str | None = None) -> RuntimeSnapshot:
453
  self.start()
454
  if snapshot_id:
455
- return self.get_snapshot(snapshot_id)
 
 
456
 
457
  stored = _run_coro_sync(self.store.select_entry(strategy=self.selection_strategy))
458
- return RuntimeSnapshot(snapshot_id=stored.snapshot_id, snapshot=stored.snapshot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
  def get_snapshot(self, snapshot_id: str) -> RuntimeSnapshot:
461
  self.start()
@@ -542,6 +609,18 @@ class ManagedSnapshotRuntime:
542
 
543
  def _generate_and_store_snapshot(self) -> str:
544
  last_error: str | None = None
 
 
 
 
 
 
 
 
 
 
 
 
545
  for attempt in range(1, self.generation_retries + 1):
546
  context = self._build_context()
547
  parent_entry = self._select_parent_entry()
@@ -588,7 +667,20 @@ class ManagedSnapshotRuntime:
588
  def _build_context(self) -> BuildContext:
589
  seed = self._generation_counter
590
  self._generation_counter += 1
591
- tier = int(self.manifest.get("tier", 1) or 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
592
  context = self.curriculum.build_context(seed=seed, tier=tier)
593
  context.episode_count = self.mutator.episode_count
594
  if self.live_admission_enabled:
@@ -741,6 +833,26 @@ class ManagedSnapshotRuntime:
741
  raise RuntimeError(f"no running containers found for project {project_name}")
742
  return ContainerSet(project_name=project_name, container_ids=container_ids)
743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744
  def _deploy_snapshot_artifacts(
745
  self,
746
  snapshot: SnapshotSpec,
@@ -764,7 +876,11 @@ class ManagedSnapshotRuntime:
764
  sql_file.write_text(content, encoding="utf-8")
765
  try:
766
  await containers.cp("db", str(sql_file), "/tmp/_snapshot.sql")
767
- await containers.exec("db", "mysql -u root -pr00tP@ss! < /tmp/_snapshot.sql")
 
 
 
 
768
  await containers.exec("db", "rm -f /tmp/_snapshot.sql")
769
  finally:
770
  sql_file.unlink(missing_ok=True)
@@ -831,6 +947,10 @@ class ManagedSnapshotRuntime:
831
  snapshot_id: str,
832
  ) -> SnapshotSpec:
833
  rendered = snapshot.model_copy(deep=True)
 
 
 
 
834
 
835
  topology = dict(rendered.topology)
836
  topology["snapshot_id"] = snapshot_id
 
371
  self._stop_event = threading.Event()
372
  self._started = False
373
  self._generation_counter = 0
374
+ self._recent_acquisitions: list[str] = []
375
 
376
  @classmethod
377
  def from_env(cls) -> "ManagedSnapshotRuntime":
 
453
def acquire_snapshot(self, *, snapshot_id: str | None = None) -> RuntimeSnapshot:
    """Hand out a snapshot for a new episode and record the acquisition.

    Args:
        snapshot_id: When given, that exact snapshot is fetched via
            ``get_snapshot``.  Otherwise the store picks a candidate
            using the configured selection strategy.

    Returns:
        The acquired ``RuntimeSnapshot``.

    When the store picks the candidate and there is acquisition history,
    a candidate judged non-diverse by ``_is_diverse`` is swapped for an
    alternative from ``_find_diverse_snapshot`` if one is available.
    """
    self.start()

    if snapshot_id:
        acquired = self.get_snapshot(snapshot_id)
    else:
        candidate = _run_coro_sync(
            self.store.select_entry(strategy=self.selection_strategy)
        )
        # Only run the diversity check once there is history to compare to.
        if self._recent_acquisitions and not self._is_diverse(candidate.snapshot):
            replacement = self._find_diverse_snapshot(candidate.snapshot_id)
            if replacement is not None:
                candidate = replacement
        acquired = RuntimeSnapshot(
            snapshot_id=candidate.snapshot_id, snapshot=candidate.snapshot
        )

    self._track_acquisition(acquired.snapshot_id)
    return acquired
472
+
473
+ def _track_acquisition(self, snapshot_id: str) -> None:
474
+ """Record a snapshot acquisition, keeping at most 10 entries."""
475
+ self._recent_acquisitions.append(snapshot_id)
476
+ if len(self._recent_acquisitions) > 10:
477
+ del self._recent_acquisitions[: len(self._recent_acquisitions) - 10]
478
+
479
+ def _recent_vuln_types(self) -> set[str]:
480
+ """Collect vuln types from the last 3 acquired snapshots."""
481
+ recent_ids = self._recent_acquisitions[-3:]
482
+ if not recent_ids:
483
+ return set()
484
+
485
+ all_meta = self.list_snapshots()
486
+ meta_by_id = {m.get("snapshot_id"): m for m in all_meta}
487
+ vuln_types: set[str] = set()
488
+ for sid in recent_ids:
489
+ meta = meta_by_id.get(sid)
490
+ if meta:
491
+ vuln_types.update(meta.get("vuln_classes", []))
492
+ return vuln_types
493
+
494
+ def _is_diverse(self, snapshot: SnapshotSpec) -> bool:
495
+ """Return True if *snapshot* has at least one vuln type not in recent history."""
496
+ recent = self._recent_vuln_types()
497
+ if not recent:
498
+ return True
499
+ candidate_vulns = {v.type for v in snapshot.truth_graph.vulns}
500
+ if not candidate_vulns:
501
+ return True
502
+ # Diverse if at least one vuln type is NOT in the recent set
503
+ return not candidate_vulns.issubset(recent)
504
+
505
+ def _find_diverse_snapshot(
506
+ self, exclude_id: str
507
+ ) -> "StoredSnapshot | None":
508
+ """Try to find a snapshot in the store whose vulns don't fully overlap."""
509
+ from open_range.builder.snapshot_store import StoredSnapshot
510
+
511
+ all_meta = self.list_snapshots()
512
+ recent = self._recent_vuln_types()
513
+
514
+ for meta in all_meta:
515
+ sid = meta.get("snapshot_id", "")
516
+ if sid == exclude_id:
517
+ continue
518
+ candidate_vulns = set(meta.get("vuln_classes", []))
519
+ if not candidate_vulns or not candidate_vulns.issubset(recent):
520
+ try:
521
+ entry = _run_coro_sync(self.store.get_entry(sid))
522
+ return entry
523
+ except Exception: # noqa: BLE001
524
+ continue
525
+ return None
526
 
527
  def get_snapshot(self, snapshot_id: str) -> RuntimeSnapshot:
528
  self.start()
 
609
 
610
  def _generate_and_store_snapshot(self) -> str:
611
  last_error: str | None = None
612
+ parent_snapshot: SnapshotSpec | None = None
613
+ parent_snapshot_id: str | None = None
614
+ existing = self.list_snapshots()
615
+ if existing:
616
+ parent_snapshot_id = str(existing[0].get("snapshot_id", "") or "")
617
+ if parent_snapshot_id:
618
+ try:
619
+ parent_snapshot = _run_coro_sync(self.store.get(parent_snapshot_id))
620
+ except FileNotFoundError:
621
+ parent_snapshot = None
622
+ parent_snapshot_id = None
623
+
624
  for attempt in range(1, self.generation_retries + 1):
625
  context = self._build_context()
626
  parent_entry = self._select_parent_entry()
 
667
  def _build_context(self) -> BuildContext:
668
  seed = self._generation_counter
669
  self._generation_counter += 1
670
+ base_tier = int(self.manifest.get("tier", 1) or 1)
671
+
672
+ # Curriculum progression: if the red agent has been solving at a high
673
+ # rate over the last 10 completed episodes, bump the effective tier.
674
+ tier = base_tier
675
+ completed = [o for o in self.curriculum.history if o.completed]
676
+ recent_completed = completed[-10:]
677
+ if len(recent_completed) >= 10:
678
+ recent_solve_rate = sum(
679
+ 1 for o in recent_completed if o.red_solved
680
+ ) / len(recent_completed)
681
+ if recent_solve_rate > 0.8:
682
+ tier = min(base_tier + 1, 5)
683
+
684
  context = self.curriculum.build_context(seed=seed, tier=tier)
685
  context.episode_count = self.mutator.episode_count
686
  if self.live_admission_enabled:
 
833
  raise RuntimeError(f"no running containers found for project {project_name}")
834
  return ContainerSet(project_name=project_name, container_ids=container_ids)
835
 
836
+ @staticmethod
837
+ def _mysql_credentials(snapshot: SnapshotSpec) -> str:
838
+ """Build MySQL CLI credential flags from the snapshot topology.
839
+
840
+ Searches ``topology["users"]`` for a user whose ``hosts`` list
841
+ contains ``"db"``. Returns ``-u <user> -p<password>`` for the
842
+ first match, or ``-u root`` (no password) as a safe fallback.
843
+ """
844
+ if isinstance(snapshot.topology, dict):
845
+ users = snapshot.topology.get("users", [])
846
+ for user in users:
847
+ hosts = user.get("hosts", [])
848
+ if "db" in hosts:
849
+ uname = user.get("username", "root")
850
+ pwd = user.get("password", "")
851
+ if pwd:
852
+ return f"-u {uname} -p{pwd}"
853
+ return f"-u {uname}"
854
+ return "-u root"
855
+
856
  def _deploy_snapshot_artifacts(
857
  self,
858
  snapshot: SnapshotSpec,
 
876
  sql_file.write_text(content, encoding="utf-8")
877
  try:
878
  await containers.cp("db", str(sql_file), "/tmp/_snapshot.sql")
879
+ mysql_creds = self._mysql_credentials(snapshot)
880
+ await containers.exec(
881
+ "db",
882
+ f"mysql {mysql_creds} < /tmp/_snapshot.sql",
883
+ )
884
  await containers.exec("db", "rm -f /tmp/_snapshot.sql")
885
  finally:
886
  sql_file.unlink(missing_ok=True)
 
947
  snapshot_id: str,
948
  ) -> SnapshotSpec:
949
  rendered = snapshot.model_copy(deep=True)
950
+ rendered.lineage = rendered.lineage.model_copy(deep=True)
951
+ rendered.lineage.snapshot_id = snapshot_id
952
+ if not rendered.lineage.root_snapshot_id:
953
+ rendered.lineage.root_snapshot_id = snapshot_id
954
 
955
  topology = dict(rendered.topology)
956
  topology["snapshot_id"] = snapshot_id
src/open_range/server/zone_router.py CHANGED
@@ -6,6 +6,9 @@ zones can reach which other zones on which ports.
6
 
7
  The agent experiences identical training signal to a
8
  multi-container setup with real iptables rules.
 
 
 
9
  """
10
 
11
  from __future__ import annotations
@@ -15,66 +18,54 @@ from typing import Any
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
- # Default Tier 1 zone routing table
19
- # Maps (from_zone, to_zone) -> set of allowed ports
20
- ZONE_ROUTES: dict[tuple[str, str], set[int]] = {
21
- ("external", "dmz"): {80, 443, 25},
22
- ("dmz", "internal"): {3306, 445},
23
- ("dmz", "management"): {389, 636},
24
- ("internal", "management"): {389},
25
- ("management", "dmz"): {514},
26
- ("management", "internal"): {514},
27
- }
28
-
29
- # Host -> zone mapping for Tier 1
30
- HOST_ZONES: dict[str, str] = {
31
- "attacker": "external",
32
- "firewall": "external", # multi-homed but agent sees external
33
- "web": "dmz",
34
- "mail": "dmz",
35
- "db": "internal",
36
- "files": "internal",
37
- "ldap": "management",
38
- "siem": "management",
39
- }
40
-
41
- # Host -> localhost port mapping (all services on localhost in subprocess mode)
42
- HOST_PORTS: dict[str, dict[str, int]] = {
43
- "web": {"http": 80, "https": 443},
44
- "mail": {"smtp": 25},
45
- "db": {"mysql": 3306},
46
- "files": {"smb": 445},
47
- "ldap": {"ldap": 389, "ldaps": 636},
48
- "siem": {"syslog": 514},
49
- }
50
-
51
 
52
  @dataclass
53
  class ZoneRouter:
54
- """Enforces network zone routing policy."""
 
 
 
 
 
 
 
 
55
 
56
- routes: dict[tuple[str, str], set[int]] = field(default_factory=lambda: dict(ZONE_ROUTES))
57
- host_zones: dict[str, str] = field(default_factory=lambda: dict(HOST_ZONES))
 
58
 
59
  @classmethod
60
  def from_snapshot(cls, topology: dict[str, Any]) -> "ZoneRouter":
61
- """Build router from snapshot topology and firewall rules."""
 
 
 
 
 
 
 
 
 
 
 
 
62
  router = cls()
63
 
64
- # Override host_zones from topology
65
  for host in topology.get("hosts", []):
66
  if isinstance(host, dict):
67
  name = host.get("name", "")
68
- zone = host.get("zone", "")
69
- if name and zone:
70
  router.host_zones[name] = zone
71
  elif isinstance(host, str):
72
- pass # keep defaults
 
73
 
74
- # Override routes from firewall_rules
75
  rules = topology.get("firewall_rules", [])
76
  if rules:
77
- router.routes = {}
78
  for rule in rules:
79
  action = rule.get("action", "deny")
80
  if action != "allow":
@@ -85,9 +76,25 @@ class ZoneRouter:
85
  if from_z and to_z:
86
  key = (from_z, to_z)
87
  router.routes[key] = router.routes.get(key, set()) | ports
 
 
88
 
89
  return router
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def can_reach(self, from_zone: str, to_zone: str, port: int) -> bool:
92
  """Check if a connection from one zone to another on a port is allowed."""
93
  if from_zone == to_zone:
@@ -103,12 +110,14 @@ class ZoneRouter:
103
  """Check if from_host can access target_host on port.
104
 
105
  Returns (allowed, reason).
 
106
  """
107
  from_zone = self.get_zone(from_host)
108
  to_zone = self.get_zone(target_host)
109
 
110
  if from_zone == "unknown" or to_zone == "unknown":
111
- return True, "unknown zone, allowing" # permissive for unknown hosts
 
112
 
113
  if self.can_reach(from_zone, to_zone, port):
114
  logger.debug("ALLOW %s(%s) -> %s(%s):%d", from_host, from_zone, target_host, to_zone, port)
 
6
 
7
  The agent experiences identical training signal to a
8
  multi-container setup with real iptables rules.
9
+
10
+ All routing data comes from the snapshot/manifest topology.
11
+ No hardcoded infrastructure constants.
12
  """
13
 
14
  from __future__ import annotations
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  @dataclass
23
  class ZoneRouter:
24
+ """Enforces network zone routing policy.
25
+
26
+ Must be constructed via ``from_snapshot()`` or ``from_manifest()``
27
+ to load topology-driven routes and host-zone mappings. The bare
28
+ constructor creates an empty (deny-all) router.
29
+ """
30
+
31
+ routes: dict[tuple[str, str], set[int]] = field(default_factory=dict)
32
+ host_zones: dict[str, str] = field(default_factory=dict)
33
 
34
+ # ------------------------------------------------------------------ #
35
+ # Constructors
36
+ # ------------------------------------------------------------------ #
37
 
38
  @classmethod
39
  def from_snapshot(cls, topology: dict[str, Any]) -> "ZoneRouter":
40
+ """Build router from snapshot topology and firewall rules.
41
+
42
+ This is the primary constructor. It reads ``hosts`` and
43
+ ``firewall_rules`` from the topology dict to populate
44
+ ``host_zones`` and ``routes``.
45
+
46
+ If ``firewall_rules`` is missing or empty, a restrictive default
47
+ is generated: same-zone traffic is always allowed (handled by
48
+ ``can_reach``), and all cross-zone traffic is denied.
49
+
50
+ If a host entry lacks a ``zone`` field, its zone is inferred as
51
+ ``"unknown"``.
52
+ """
53
  router = cls()
54
 
55
+ # Build host_zones from topology hosts list
56
  for host in topology.get("hosts", []):
57
  if isinstance(host, dict):
58
  name = host.get("name", "")
59
+ zone = host.get("zone", "unknown")
60
+ if name:
61
  router.host_zones[name] = zone
62
  elif isinstance(host, str):
63
+ # String-only entries get zone inferred as "unknown"
64
+ router.host_zones[host] = "unknown"
65
 
66
+ # Build routes from firewall_rules
67
  rules = topology.get("firewall_rules", [])
68
  if rules:
 
69
  for rule in rules:
70
  action = rule.get("action", "deny")
71
  if action != "allow":
 
76
  if from_z and to_z:
77
  key = (from_z, to_z)
78
  router.routes[key] = router.routes.get(key, set()) | ports
79
+ # else: no firewall_rules → routes stays empty → cross-zone denied,
80
+ # same-zone allowed (handled by can_reach)
81
 
82
  return router
83
 
84
@classmethod
def from_manifest(cls, manifest: dict[str, Any]) -> "ZoneRouter":
    """Create a router directly from a raw manifest mapping.

    The manifest's ``"topology"`` entry is handed to ``from_snapshot``;
    when that key is absent the manifest itself is treated as the
    topology.
    """
    return cls.from_snapshot(manifest.get("topology", manifest))
93
+
94
+ # ------------------------------------------------------------------ #
95
+ # Query methods
96
+ # ------------------------------------------------------------------ #
97
+
98
  def can_reach(self, from_zone: str, to_zone: str, port: int) -> bool:
99
  """Check if a connection from one zone to another on a port is allowed."""
100
  if from_zone == to_zone:
 
110
  """Check if from_host can access target_host on port.
111
 
112
  Returns (allowed, reason).
113
+ Unknown zones are denied (fail-closed).
114
  """
115
  from_zone = self.get_zone(from_host)
116
  to_zone = self.get_zone(target_host)
117
 
118
  if from_zone == "unknown" or to_zone == "unknown":
119
+ unknown = from_zone if from_zone == "unknown" else to_zone
120
+ return False, f"unknown zone: {unknown}"
121
 
122
  if self.can_reach(from_zone, to_zone, port):
123
  logger.debug("ALLOW %s(%s) -> %s(%s):%d", from_host, from_zone, target_host, to_zone, port)