Aaron Brown Claude Opus 4.6 committed on
Commit
769dd2e
·
1 Parent(s): 4a77f25

Add task engine, exposure policy, auth scenario, pivot mechanics, curriculum wiring

Browse files

Implements five GitHub issues:

- #17 Task engine: TaskType enum, TaskSpec milestones/success_conditions,
milestone checking in step(), milestones_completed in RangeState
- #18 Exposure policy: ExposurePolicy model (public/hidden/authenticated/
misconfigured), added to Host in manifests/schema.py
- #25 Auth scenario: auth/logout meta-commands in environment, session
tracking via active_sessions and auth_attempts in RangeState
- #26 Pivot mechanics: access_grants and pivot_history in RangeState,
credential leak detection in command output via _check_pivot()
- #34 Curriculum feedback: CurriculumTracker.update_from_result() method,
run_episode() wires results to tracker when provided

All 354 tests pass (311 existing + 43 new). No existing tests broken.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

manifests/schema.py CHANGED
@@ -217,6 +217,14 @@ class OperationalContext(BaseModel, extra="allow"):
217
  # Topology primitives
218
  # ---------------------------------------------------------------------------
219
 
 
 
 
 
 
 
 
 
220
  class Host(BaseModel):
221
  """A single host (container) in the range topology."""
222
 
@@ -242,6 +250,7 @@ class Host(BaseModel):
242
  default="ubuntu:22.04",
243
  description="Base OS image for the container",
244
  )
 
245
 
246
 
247
  class Network(BaseModel):
 
217
  # Topology primitives
218
  # ---------------------------------------------------------------------------
219
 
220
class ExposurePolicy(BaseModel):
    """Per-host exposure configuration.

    Attached to ``Host`` through its ``exposure`` field. Every field has a
    default, so ``ExposurePolicy()`` describes a ``"public"`` host with no
    authentication requirement and empty notes.
    """

    # Declared with Field(default=..., description=...) to match the way the
    # neighbouring Host model documents its fields.
    level: Literal["public", "hidden", "authenticated", "misconfigured"] = Field(
        default="public",
        description="Exposure level of the host",
    )
    auth_required: bool = Field(
        default=False,
        description="Whether authentication is required",
    )
    notes: str = Field(
        default="",
        description="Free-form notes about the exposure configuration",
    )
226
+
227
+
228
  class Host(BaseModel):
229
  """A single host (container) in the range topology."""
230
 
 
250
  default="ubuntu:22.04",
251
  description="Base OS image for the container",
252
  )
253
+ exposure: ExposurePolicy = Field(default_factory=ExposurePolicy)
254
 
255
 
256
  class Network(BaseModel):
src/open_range/agents/episode.py CHANGED
@@ -14,6 +14,7 @@ from open_range.agents.protocol import EpisodeMetrics, EpisodeResult
14
 
15
  if TYPE_CHECKING:
16
  from open_range.agents.protocol import RangeAgent
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
@@ -80,6 +81,7 @@ def run_episode(
80
  max_steps: int = 100,
81
  red_model: str = "",
82
  blue_model: str = "",
 
83
  ) -> EpisodeResult:
84
  """Run one tandem Red + Blue episode.
85
 
@@ -175,4 +177,24 @@ def run_episode(
175
  total_flags,
176
  )
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  return result
 
14
 
15
  if TYPE_CHECKING:
16
  from open_range.agents.protocol import RangeAgent
17
+ from open_range.training.curriculum import CurriculumTracker
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
81
  max_steps: int = 100,
82
  red_model: str = "",
83
  blue_model: str = "",
84
+ curriculum: CurriculumTracker | None = None,
85
  ) -> EpisodeResult:
86
  """Run one tandem Red + Blue episode.
87
 
 
177
  total_flags,
178
  )
179
 
180
+ # Curriculum feedback wiring (#34)
181
+ if curriculum is not None:
182
+ # Extract vuln classes from snapshot truth graph if available
183
+ vuln_classes: list[str] = []
184
+ if snapshot and hasattr(snapshot, "truth_graph") and snapshot.truth_graph:
185
+ tg = snapshot.truth_graph
186
+ vulns = getattr(tg, "vulns", [])
187
+ vuln_classes = [getattr(v, "type", "") for v in vulns if getattr(v, "type", "")]
188
+
189
+ curriculum.update_from_result({
190
+ "snapshot_id": snapshot_id,
191
+ "vuln_classes": vuln_classes,
192
+ "outcome": outcome,
193
+ "flags_found": list(flags_found),
194
+ "steps": step,
195
+ "tier": tier,
196
+ "red_model": red_model or getattr(red, "model", ""),
197
+ "blue_model": blue_model or getattr(blue, "model", ""),
198
+ })
199
+
200
  return result
src/open_range/protocols.py CHANGED
@@ -8,6 +8,7 @@ Three pluggable infrastructure components:
8
 
9
  from __future__ import annotations
10
 
 
11
  from typing import Any, Literal, Protocol, runtime_checkable
12
 
13
  from pydantic import BaseModel, Field
@@ -123,11 +124,27 @@ class NPCTrafficSpec(BaseModel):
123
  scripts: list[str] = Field(default_factory=list)
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
126
  class TaskSpec(BaseModel):
127
  """Agent-facing task descriptions (no leakage of internals)."""
128
 
129
  red_briefing: str = ""
130
  blue_briefing: str = ""
 
 
 
 
 
131
 
132
 
133
  class SnapshotSpec(BaseModel):
 
8
 
9
  from __future__ import annotations
10
 
11
+ from enum import Enum
12
  from typing import Any, Literal, Protocol, runtime_checkable
13
 
14
  from pydantic import BaseModel, Field
 
124
  scripts: list[str] = Field(default_factory=list)
125
 
126
 
127
class TaskType(str, Enum):
    """Kinds of task an agent can be assigned.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, e.g. ``TaskType.EXPLOIT == "exploit"``.
    """

    EXPLOIT = "exploit"
    INVESTIGATE = "investigate"
    PATCH = "patch"
    REPORT = "report"
    ENDPOINT_QUERY = "endpoint_query"
    MULTI_STEP = "multi_step"
136
+
137
+
138
  class TaskSpec(BaseModel):
139
  """Agent-facing task descriptions (no leakage of internals)."""
140
 
141
  red_briefing: str = ""
142
  blue_briefing: str = ""
143
+ task_type: str = "exploit" # Use str not enum for flexibility
144
+ milestones: list[str] = Field(default_factory=list) # For multi_step tasks
145
+ success_conditions: list[dict[str, Any]] = Field(
146
+ default_factory=list,
147
+ ) # [{type: "flag", value: "..."}, {type: "endpoint", url: "...", expect: "..."}]
148
 
149
 
150
  class SnapshotSpec(BaseModel):
src/open_range/server/environment.py CHANGED
@@ -20,7 +20,7 @@ import time
20
  from typing import Any
21
  from uuid import uuid4
22
 
23
- from open_range.protocols import SnapshotSpec
24
 
25
  from open_range.server.models import RangeAction, RangeObservation, RangeState
26
 
@@ -40,7 +40,7 @@ except ImportError:
40
  _HAS_OPENENV = False
41
 
42
  # Meta-commands processed by the environment itself (not forwarded to containers)
43
- META_COMMANDS = {"submit_flag", "submit_evidence", "submit_finding"}
44
 
45
  # Maximum steps before forced termination
46
  DEFAULT_MAX_STEPS = 100
@@ -264,6 +264,152 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
264
  stdout="Finding submitted and recorded.",
265
  )
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  # -----------------------------------------------------------------
268
  # Target resolution
269
  # -----------------------------------------------------------------
@@ -397,6 +543,18 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
397
  self._check_termination(obs)
398
  return obs
399
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  # Route to container
401
  target = self._resolve_target(action)
402
  timeout = timeout_s or self._exec_timeout
@@ -416,6 +574,14 @@ class RangeEnvironment(_BASE): # type: ignore[misc]
416
  else:
417
  self._blue_history.append(action_record)
418
 
 
 
 
 
 
 
 
 
419
  # Build observation
420
  obs = RangeObservation(
421
  stdout=stdout,
 
20
  from typing import Any
21
  from uuid import uuid4
22
 
23
+ from open_range.protocols import SnapshotSpec, TaskSpec
24
 
25
  from open_range.server.models import RangeAction, RangeObservation, RangeState
26
 
 
40
  _HAS_OPENENV = False
41
 
42
  # Meta-commands processed by the environment itself (not forwarded to containers)
43
+ META_COMMANDS = {"submit_flag", "submit_evidence", "submit_finding", "auth", "logout"}
44
 
45
  # Maximum steps before forced termination
46
  DEFAULT_MAX_STEPS = 100
 
264
  stdout="Finding submitted and recorded.",
265
  )
266
 
267
+ # -----------------------------------------------------------------
268
+ # Auth scenario (#25)
269
+ # -----------------------------------------------------------------
270
+
271
def _handle_auth(self, action: RangeAction) -> RangeObservation:
    """Process an ``auth <host> <username> <password>`` command.

    Credentials are checked against the ``users`` list of the snapshot
    topology. A user entry matches when its ``username`` and ``password``
    equal the supplied values and ``host`` is contained in its ``hosts``.

    Every well-formed attempt, successful or not, is appended to
    ``state.auth_attempts``; the password itself is not stored there. A
    successful attempt also sets ``state.active_sessions[host]`` to the
    username and adds ``"<host>:shell"`` to ``state.access_grants`` if it
    is not already present.

    Args:
        action: Action whose ``command`` holds the ``auth`` invocation.
            Tokens after the password are ignored.

    Returns:
        An observation with a confirmation on ``stdout`` on success, or
        with a usage or failure message on ``stderr`` otherwise.
    """
    parts = action.command.strip().split()
    if len(parts) < 4:
        return RangeObservation(
            stdout="",
            stderr="Usage: auth <host> <username> <password>",
        )
    host, username, password = parts[1], parts[2], parts[3]

    attempt = {
        "step": self._state.step_count,
        "host": host,
        "username": username,
        "success": False,
        "time": time.time(),
    }

    # Look up credentials in the snapshot topology. Missing or ``None``
    # ``users`` / ``hosts`` values and non-dict user entries are treated as
    # "no match" instead of raising.
    topology = self._snapshot.topology if self._snapshot else None
    users = (topology.get("users") or []) if isinstance(topology, dict) else []
    authenticated = any(
        isinstance(user, dict)
        and user.get("username") == username
        and user.get("password") == password
        and host in (user.get("hosts") or [])
        for user in users
    )

    attempt["success"] = authenticated
    self._state.auth_attempts.append(attempt)

    if not authenticated:
        return RangeObservation(
            stdout="",
            stderr=f"Authentication failed for {username} on {host}.",
        )

    self._state.active_sessions[host] = username
    # Record access grant for pivot tracking.
    grant = f"{host}:shell"
    if grant not in self._state.access_grants:
        self._state.access_grants.append(grant)
    return RangeObservation(
        stdout=f"Authenticated as {username} on {host}.",
    )
325
+
326
def _handle_logout(self, action: RangeAction) -> RangeObservation:
    """Process a ``logout <host>`` command.

    Removes the host's entry from ``state.active_sessions`` when one
    exists.

    Returns:
        An observation confirming the logout on ``stdout``, or a usage /
        "no active session" message on ``stderr``.
    """
    tokens = action.command.strip().split()
    if len(tokens) < 2:
        return RangeObservation(stdout="", stderr="Usage: logout <host>")

    host = tokens[1]
    sessions = self._state.active_sessions
    if host not in sessions:
        return RangeObservation(
            stdout="",
            stderr=f"No active session on {host}.",
        )

    user = sessions.pop(host)
    return RangeObservation(stdout=f"Logged out {user} from {host}.")
346
+
347
+ # -----------------------------------------------------------------
348
+ # Milestone checking (#17 task engine)
349
+ # -----------------------------------------------------------------
350
+
351
+ def _check_milestone(self, output: str) -> str | None:
352
+ """Check if command output satisfies the next pending milestone.
353
+
354
+ Returns the milestone string if matched, None otherwise.
355
+ """
356
+ if not self._snapshot:
357
+ return None
358
+
359
+ task = self._snapshot.task
360
+ if isinstance(task, dict):
361
+ task_type = task.get("task_type", "exploit")
362
+ milestones = task.get("milestones", [])
363
+ elif isinstance(task, TaskSpec):
364
+ task_type = task.task_type
365
+ milestones = task.milestones
366
+ else:
367
+ return None
368
+
369
+ if task_type != "multi_step" or not milestones:
370
+ return None
371
+
372
+ # Check each incomplete milestone against the output
373
+ completed = set(self._state.milestones_completed)
374
+ for ms in milestones:
375
+ if ms not in completed and ms.lower() in output.lower():
376
+ return ms
377
+ return None
378
+
379
+ # -----------------------------------------------------------------
380
+ # Pivot mechanics (#26)
381
+ # -----------------------------------------------------------------
382
+
383
+ def _check_pivot(self, action: RangeAction, stdout: str) -> None:
384
+ """Detect credential or access token leaks in command output.
385
+
386
+ When output contains credentials that match the truth graph,
387
+ record an access grant and log the pivot event.
388
+ """
389
+ if not self._snapshot or not isinstance(self._snapshot.topology, dict):
390
+ return
391
+
392
+ users = self._snapshot.topology.get("users", [])
393
+ for user in users:
394
+ uname = user.get("username", "")
395
+ pwd = user.get("password", "")
396
+ if not uname or not pwd:
397
+ continue
398
+ # Check if credentials appear in the command output
399
+ if uname in stdout and pwd in stdout:
400
+ for host in user.get("hosts", []):
401
+ grant = f"{host}:credential"
402
+ if grant not in self._state.access_grants:
403
+ self._state.access_grants.append(grant)
404
+ # Determine source host from the action target
405
+ source = self._resolve_target(action)
406
+ self._state.pivot_history.append({
407
+ "from": source,
408
+ "to": host,
409
+ "via": "credential_reuse",
410
+ "username": uname,
411
+ })
412
+
413
  # -----------------------------------------------------------------
414
  # Target resolution
415
  # -----------------------------------------------------------------
 
543
  self._check_termination(obs)
544
  return obs
545
 
546
+ if cmd_name == "auth":
547
+ obs = self._handle_auth(action)
548
+ obs = self._apply_rewards(action, obs)
549
+ self._check_termination(obs)
550
+ return obs
551
+
552
+ if cmd_name == "logout":
553
+ obs = self._handle_logout(action)
554
+ obs = self._apply_rewards(action, obs)
555
+ self._check_termination(obs)
556
+ return obs
557
+
558
  # Route to container
559
  target = self._resolve_target(action)
560
  timeout = timeout_s or self._exec_timeout
 
574
  else:
575
  self._blue_history.append(action_record)
576
 
577
+ # Check for milestone completion (#17)
578
+ milestone = self._check_milestone(stdout)
579
+ if milestone and milestone not in self._state.milestones_completed:
580
+ self._state.milestones_completed.append(milestone)
581
+
582
+ # Check for pivot opportunities (#26)
583
+ self._check_pivot(action, stdout)
584
+
585
  # Build observation
586
  obs = RangeObservation(
587
  stdout=stdout,
src/open_range/server/models.py CHANGED
@@ -8,6 +8,8 @@ from __future__ import annotations
8
 
9
  from typing import Any, Literal
10
 
 
 
11
  try:
12
  from openenv.core.env_server.types import Action, Observation, State
13
  except ImportError:
@@ -45,3 +47,11 @@ class RangeState(State):
45
  flags_found: list[str] = []
46
  services_status: dict[str, Any] = {}
47
  tier: int = 1
 
 
 
 
 
 
 
 
 
8
 
9
  from typing import Any, Literal
10
 
11
+ from pydantic import Field
12
+
13
  try:
14
  from openenv.core.env_server.types import Action, Observation, State
15
  except ImportError:
 
47
  flags_found: list[str] = []
48
  services_status: dict[str, Any] = {}
49
  tier: int = 1
50
+ # Auth scenario (#25): session tracking
51
+ active_sessions: dict[str, str] = Field(default_factory=dict) # host -> username
52
+ auth_attempts: list[dict[str, Any]] = Field(default_factory=list)
53
+ # Pivot mechanics (#26): access and lateral movement tracking
54
+ access_grants: list[str] = Field(default_factory=list) # ["host:service", ...]
55
+ pivot_history: list[dict[str, str]] = Field(default_factory=list) # [{from: "web", to: "db", via: "credential_reuse"}]
56
+ # Task engine (#17): milestone tracking
57
+ milestones_completed: list[str] = Field(default_factory=list)
src/open_range/training/curriculum.py CHANGED
@@ -173,3 +173,56 @@ class CurriculumTracker:
173
  else:
174
  rates[vc] = 0.0
175
  return rates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  else:
174
  rates[vc] = 0.0
175
  return rates
176
+
177
def update_from_result(self, result: dict) -> None:
    """Update curriculum stats from an episode result.

    Translates a loosely structured result dict into a single
    ``record_episode()`` call. All keys are optional:

    - ``snapshot_id`` (str): episode/snapshot identifier (default ``""``)
    - ``vuln_classes`` (list[str]): vulnerability classes in the episode
    - ``red_solved`` (bool): whether Red captured a flag
    - ``blue_detected`` (bool): whether Blue detected the attack
    - ``tier`` (int): difficulty tier (default ``1``)
    - ``attack_surfaces`` (list[str]): injection points used
    - ``outcome`` (str): episode outcome (``red_win``, ``blue_win``, ``timeout``)
    - ``flags_found`` (list[str]): captured flags
    - ``steps`` (int): total steps taken
    - ``red_model`` / ``blue_model`` (str): model identifiers

    If ``red_solved`` is absent it is inferred as ``outcome == "red_win"``
    or a non-empty ``flags_found``. If ``blue_detected`` is absent it is
    inferred as ``outcome == "blue_win"``. Explicit values always win over
    inference.

    ``outcome``, ``flags_found``, ``steps``, ``red_model`` and
    ``blue_model`` are forwarded as ``extra`` metadata when present;
    ``extra`` is ``None`` when none of them are.
    """
    outcome = result.get("outcome", "")

    # Infer solve/detect status if not explicitly provided.
    if "red_solved" in result:
        red_solved = bool(result["red_solved"])
    else:
        red_solved = outcome == "red_win" or bool(result.get("flags_found", []))

    if "blue_detected" in result:
        blue_detected = bool(result["blue_detected"])
    else:
        blue_detected = outcome == "blue_win"

    # A tuple (not a set) keeps the key order of ``extra`` stable from run
    # to run, so recorded history serializes deterministically.
    extra_keys = ("outcome", "flags_found", "steps", "red_model", "blue_model")
    extra = {key: result[key] for key in extra_keys if key in result}

    self.record_episode(
        snapshot_id=result.get("snapshot_id", ""),
        vuln_classes=result.get("vuln_classes", []),
        red_solved=red_solved,
        blue_detected=blue_detected,
        tier=result.get("tier", 1),
        attack_surfaces=result.get("attack_surfaces", []),
        extra=extra or None,
    )
tests/test_curriculum_integration.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for curriculum feedback wiring (#34).
2
+
3
+ Verifies that CurriculumTracker.update_from_result() works correctly
4
+ and that run_episode() feeds results into the tracker.
5
+ """
6
+
7
+ import pytest
8
+
9
+ from open_range.training.curriculum import CurriculumTracker
10
+
11
+
12
+ class TestUpdateFromResult:
13
+ """CurriculumTracker.update_from_result() parses episode results."""
14
+
15
+ def test_basic_update(self):
16
+ tracker = CurriculumTracker()
17
+ tracker.update_from_result({
18
+ "snapshot_id": "snap-001",
19
+ "vuln_classes": ["sqli", "xss"],
20
+ "red_solved": True,
21
+ "blue_detected": False,
22
+ "tier": 1,
23
+ })
24
+ assert len(tracker.episode_history) == 1
25
+ assert tracker.vuln_stats["sqli"]["attempts"] == 1
26
+ assert tracker.vuln_stats["sqli"]["red_solves"] == 1
27
+ assert tracker.vuln_stats["xss"]["blue_detects"] == 0
28
+
29
+ def test_infer_red_solved_from_outcome(self):
30
+ tracker = CurriculumTracker()
31
+ tracker.update_from_result({
32
+ "snapshot_id": "snap-002",
33
+ "vuln_classes": ["sqli"],
34
+ "outcome": "red_win",
35
+ "tier": 1,
36
+ })
37
+ assert tracker.episode_history[-1]["red_solved"] is True
38
+
39
+ def test_infer_red_solved_from_flags(self):
40
+ tracker = CurriculumTracker()
41
+ tracker.update_from_result({
42
+ "snapshot_id": "snap-003",
43
+ "vuln_classes": ["idor"],
44
+ "flags_found": ["FLAG{gotcha}"],
45
+ "tier": 2,
46
+ })
47
+ assert tracker.episode_history[-1]["red_solved"] is True
48
+
49
+ def test_infer_blue_detected_from_outcome(self):
50
+ tracker = CurriculumTracker()
51
+ tracker.update_from_result({
52
+ "snapshot_id": "snap-004",
53
+ "vuln_classes": ["xss"],
54
+ "outcome": "blue_win",
55
+ "tier": 1,
56
+ })
57
+ assert tracker.episode_history[-1]["blue_detected"] is True
58
+
59
+ def test_timeout_outcome(self):
60
+ tracker = CurriculumTracker()
61
+ tracker.update_from_result({
62
+ "snapshot_id": "snap-005",
63
+ "vuln_classes": ["ssrf"],
64
+ "outcome": "timeout",
65
+ "tier": 1,
66
+ })
67
+ ep = tracker.episode_history[-1]
68
+ assert ep["red_solved"] is False
69
+ assert ep["blue_detected"] is False
70
+
71
+ def test_explicit_flags_override_inference(self):
72
+ tracker = CurriculumTracker()
73
+ tracker.update_from_result({
74
+ "snapshot_id": "snap-006",
75
+ "vuln_classes": ["sqli"],
76
+ "red_solved": False,
77
+ "blue_detected": True,
78
+ "outcome": "red_win", # Would infer True, but explicit False wins
79
+ "tier": 1,
80
+ })
81
+ ep = tracker.episode_history[-1]
82
+ assert ep["red_solved"] is False
83
+ assert ep["blue_detected"] is True
84
+
85
+ def test_extra_metadata_passed_through(self):
86
+ tracker = CurriculumTracker()
87
+ tracker.update_from_result({
88
+ "snapshot_id": "snap-007",
89
+ "vuln_classes": ["weak_creds"],
90
+ "red_solved": True,
91
+ "blue_detected": False,
92
+ "tier": 1,
93
+ "steps": 42,
94
+ "outcome": "red_win",
95
+ "red_model": "gpt-4",
96
+ "blue_model": "llama-3",
97
+ })
98
+ ep = tracker.episode_history[-1]
99
+ assert ep.get("steps") == 42
100
+ assert ep.get("outcome") == "red_win"
101
+
102
+ def test_empty_result_defaults(self):
103
+ tracker = CurriculumTracker()
104
+ tracker.update_from_result({})
105
+ assert len(tracker.episode_history) == 1
106
+ ep = tracker.episode_history[-1]
107
+ assert ep["red_solved"] is False
108
+ assert ep["blue_detected"] is False
109
+ assert ep["tier"] == 1
110
+
111
+
112
+ class TestCurriculumStatsUpdate:
113
+ """Verify that update_from_result correctly updates aggregate stats."""
114
+
115
+ def test_vuln_stats_accumulate(self):
116
+ tracker = CurriculumTracker()
117
+ for i in range(5):
118
+ tracker.update_from_result({
119
+ "snapshot_id": f"snap-{i}",
120
+ "vuln_classes": ["sqli"],
121
+ "red_solved": i % 2 == 0, # solved on 0, 2, 4
122
+ "blue_detected": i % 3 == 0, # detected on 0, 3
123
+ "tier": 1,
124
+ })
125
+ assert tracker.vuln_stats["sqli"]["attempts"] == 5
126
+ assert tracker.vuln_stats["sqli"]["red_solves"] == 3
127
+ assert tracker.vuln_stats["sqli"]["blue_detects"] == 2
128
+
129
+ def test_tier_stats_accumulate(self):
130
+ tracker = CurriculumTracker()
131
+ tracker.update_from_result({
132
+ "snapshot_id": "a",
133
+ "vuln_classes": ["sqli"],
134
+ "red_solved": True,
135
+ "blue_detected": False,
136
+ "tier": 2,
137
+ })
138
+ tracker.update_from_result({
139
+ "snapshot_id": "b",
140
+ "vuln_classes": ["xss"],
141
+ "red_solved": False,
142
+ "blue_detected": True,
143
+ "tier": 2,
144
+ })
145
+ assert tracker.tier_stats[2]["episodes"] == 2
146
+ assert tracker.tier_stats[2]["red_solves"] == 1
147
+ assert tracker.tier_stats[2]["blue_detects"] == 1
148
+
149
+ def test_build_context_after_updates(self):
150
+ tracker = CurriculumTracker()
151
+ for i in range(3):
152
+ tracker.update_from_result({
153
+ "snapshot_id": f"s{i}",
154
+ "vuln_classes": ["sqli"],
155
+ "red_solved": True,
156
+ "blue_detected": False,
157
+ "tier": 1,
158
+ })
159
+ ctx = tracker.get_build_context()
160
+ assert ctx["episode_count"] == 3
161
+ assert ctx["red_solve_rate"] == 1.0
162
+ assert ctx["blue_detect_rate"] == 0.0
163
+ assert "sqli" in ctx["previous_vuln_classes"]
164
+
165
+ def test_attack_surfaces_passed(self):
166
+ tracker = CurriculumTracker()
167
+ tracker.update_from_result({
168
+ "snapshot_id": "s1",
169
+ "vuln_classes": ["sqli"],
170
+ "red_solved": True,
171
+ "blue_detected": False,
172
+ "tier": 1,
173
+ "attack_surfaces": ["/search?q="],
174
+ })
175
+ ctx = tracker.get_build_context()
176
+ assert "/search?q=" in ctx["recent_attack_surfaces"]
177
+
178
+
179
class TestRunEpisodeCurriculumWiring:
    """run_episode() calls curriculum.update_from_result() when provided."""

    def test_run_episode_updates_curriculum(self):
        """A tracker passed to run_episode() receives one episode record."""
        from open_range.protocols import (
            FlagSpec,
            SnapshotSpec,
            TaskSpec,
            TruthGraph,
            Vulnerability,
        )
        from open_range.server.environment import RangeEnvironment
        from open_range.agents.episode import run_episode

        class ScriptedAgent:
            """Minimal agent that runs a fixed script."""

            def __init__(self, commands):
                self._commands = list(commands)
                self._idx = 0

            def reset(self, briefing, role):
                self._idx = 0

            def act(self, observation):
                # Replay the script in order, then fall back to "noop".
                if self._idx < len(self._commands):
                    cmd = self._commands[self._idx]
                    self._idx += 1
                    return cmd
                return "noop"

        env = RangeEnvironment(docker_available=False, max_steps=4)
        snapshot = SnapshotSpec(
            topology={
                "hosts": ["attacker", "web"],
                "tier": 1,
            },
            flags=[FlagSpec(id="f1", value="FLAG{x}", path="/f.txt", host="web")],
            golden_path=[],
            truth_graph=TruthGraph(
                vulns=[Vulnerability(id="v1", type="sqli", host="web")],
            ),
            task=TaskSpec(red_briefing="Go.", blue_briefing="Watch."),
        )
        env.reset(snapshot=snapshot)
        # Patch _select_snapshot to always return our snapshot
        env._select_snapshot = lambda **kw: snapshot

        red = ScriptedAgent(["submit_flag FLAG{x}", "noop"])
        blue = ScriptedAgent(["submit_finding attack found", "noop"])

        tracker = CurriculumTracker()
        # Only the side effect on the tracker is under test here, so the
        # returned EpisodeResult is intentionally not bound.
        run_episode(env, red, blue, max_steps=4, curriculum=tracker)

        assert len(tracker.episode_history) == 1
        ep = tracker.episode_history[0]
        assert ep["red_solved"] is True  # flag was captured -> red_win
        assert "sqli" in ep["vuln_classes"]

    def test_run_episode_without_curriculum(self):
        """run_episode still works when no curriculum is provided."""
        from open_range.server.environment import RangeEnvironment
        from open_range.agents.episode import run_episode

        class NoopAgent:
            def reset(self, briefing, role):
                pass

            def act(self, observation):
                return "noop"

        env = RangeEnvironment(docker_available=False, max_steps=2)
        result = run_episode(env, NoopAgent(), NoopAgent(), max_steps=2)
        assert result.outcome in ("red_win", "blue_win", "timeout")
tests/test_environment.py CHANGED
@@ -2,7 +2,14 @@
2
 
3
  import pytest
4
 
5
- from open_range.protocols import FlagSpec, GoldenPathStep, SnapshotSpec, TaskSpec, TruthGraph
 
 
 
 
 
 
 
6
  from open_range.server.environment import RangeEnvironment, _extract_command_name
7
  from open_range.server.models import RangeAction, RangeObservation, RangeState
8
 
@@ -181,3 +188,278 @@ class TestStateProperty:
181
  assert env.state.step_count == 0
182
  env.step(RangeAction(command="nmap -sV web", mode="red"))
183
  assert env.state.step_count == 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import pytest
4
 
5
+ from open_range.protocols import (
6
+ FlagSpec,
7
+ GoldenPathStep,
8
+ SnapshotSpec,
9
+ TaskSpec,
10
+ TruthGraph,
11
+ Vulnerability,
12
+ )
13
  from open_range.server.environment import RangeEnvironment, _extract_command_name
14
  from open_range.server.models import RangeAction, RangeObservation, RangeState
15
 
 
188
  assert env.state.step_count == 0
189
  env.step(RangeAction(command="nmap -sV web", mode="red"))
190
  assert env.state.step_count == 1
191
+
192
+
193
+ # -------------------------------------------------------------------
194
+ # Task engine (#17)
195
+ # -------------------------------------------------------------------
196
+
197
+
198
+ def _make_multistep_snapshot():
199
+ """Helper: snapshot with a multi_step task and milestones."""
200
+ return SnapshotSpec(
201
+ topology={
202
+ "hosts": ["attacker", "web", "db"],
203
+ "users": [
204
+ {"username": "admin", "password": "admin123", "hosts": ["web", "db"]},
205
+ ],
206
+ },
207
+ flags=[FlagSpec(id="f1", value="FLAG{ms}", path="/f.txt", host="db")],
208
+ golden_path=[],
209
+ task=TaskSpec(
210
+ red_briefing="Multi-step challenge.",
211
+ blue_briefing="Watch.",
212
+ task_type="multi_step",
213
+ milestones=["port scan complete", "credentials found", "database accessed"],
214
+ ),
215
+ )
216
+
217
+
218
+ class TestTaskEngine:
219
+ """Milestone checking for multi_step tasks (#17)."""
220
+
221
+ def test_milestone_detected_in_output(self):
222
+ env = RangeEnvironment(docker_available=False)
223
+ snapshot = _make_multistep_snapshot()
224
+ env.reset(snapshot=snapshot)
225
+
226
+ # Mock mode returns "[mock] executed on attacker: ..." which won't match.
227
+ # We need to check that _check_milestone works with the right output.
228
+ ms = env._check_milestone("Port scan complete -- found open ports")
229
+ assert ms == "port scan complete"
230
+
231
+ def test_milestone_not_duplicated(self):
232
+ env = RangeEnvironment(docker_available=False)
233
+ snapshot = _make_multistep_snapshot()
234
+ env.reset(snapshot=snapshot)
235
+
236
+ # Simulate first milestone completion
237
+ env._state.milestones_completed.append("port scan complete")
238
+ ms = env._check_milestone("Port scan complete again")
239
+ assert ms is None # Already completed
240
+
241
+ def test_milestone_returns_none_for_exploit_task(self):
242
+ env = RangeEnvironment(docker_available=False)
243
+ snapshot = SnapshotSpec(
244
+ topology={"hosts": ["attacker", "web"]},
245
+ flags=[],
246
+ golden_path=[],
247
+ task=TaskSpec(red_briefing="Go.", blue_briefing="Watch.", task_type="exploit"),
248
+ )
249
+ env.reset(snapshot=snapshot)
250
+ ms = env._check_milestone("anything here")
251
+ assert ms is None
252
+
253
+ def test_milestone_returns_none_for_no_match(self):
254
+ env = RangeEnvironment(docker_available=False)
255
+ snapshot = _make_multistep_snapshot()
256
+ env.reset(snapshot=snapshot)
257
+ ms = env._check_milestone("nothing relevant here")
258
+ assert ms is None
259
+
260
+ def test_milestones_tracked_in_state(self):
261
+ env = RangeEnvironment(docker_available=False)
262
+ snapshot = _make_multistep_snapshot()
263
+ env.reset(snapshot=snapshot)
264
+ assert env.state.milestones_completed == []
265
+
266
+ # Manually add a milestone (simulating what step() does)
267
+ env._state.milestones_completed.append("port scan complete")
268
+ assert env.state.milestones_completed == ["port scan complete"]
269
+
270
+ def test_task_type_field_on_task_spec(self):
271
+ ts = TaskSpec(task_type="multi_step", milestones=["a", "b"])
272
+ assert ts.task_type == "multi_step"
273
+ assert ts.milestones == ["a", "b"]
274
+
275
+ def test_success_conditions_on_task_spec(self):
276
+ ts = TaskSpec(
277
+ success_conditions=[
278
+ {"type": "flag", "value": "FLAG{x}"},
279
+ {"type": "endpoint", "url": "/api/data", "expect": "secret"},
280
+ ],
281
+ )
282
+ assert len(ts.success_conditions) == 2
283
+ assert ts.success_conditions[0]["type"] == "flag"
284
+
285
+
286
+ # -------------------------------------------------------------------
287
+ # Auth scenario (#25)
288
+ # -------------------------------------------------------------------
289
+
290
+
291
def _make_auth_snapshot():
    """Build the SnapshotSpec used by the auth tests.

    The topology has hosts ``attacker``, ``web`` and ``db`` plus two users:
    ``admin`` (allowed on ``web`` and ``db``) and ``guest`` (allowed on
    ``web`` only). A single flag lives on ``db``.
    """
    accounts = [
        {"username": "admin", "password": "admin123", "hosts": ["web", "db"]},
        {"username": "guest", "password": "guest", "hosts": ["web"]},
    ]
    layout = {"hosts": ["attacker", "web", "db"], "users": accounts}
    auth_flag = FlagSpec(id="f1", value="FLAG{auth}", path="/f.txt", host="db")
    briefing = TaskSpec(red_briefing="Auth challenge.", blue_briefing="Watch.")
    return SnapshotSpec(
        topology=layout,
        flags=[auth_flag],
        golden_path=[],
        task=briefing,
    )
305
+
306
+
307
class TestAuthScenario:
    """Session tracking driven by the ``auth`` and ``logout`` commands (#25)."""

    @staticmethod
    def _fresh_env():
        """Return a docker-less environment reset onto the auth snapshot."""
        range_env = RangeEnvironment(docker_available=False)
        range_env.reset(snapshot=_make_auth_snapshot())
        return range_env

    @staticmethod
    def _red(range_env, command):
        """Step *range_env* with a red-mode action and return the observation."""
        return range_env.step(RangeAction(command=command, mode="red"))

    def test_auth_success(self):
        """Correct credentials report success and open a session on the host."""
        range_env = self._fresh_env()
        result = self._red(range_env, "auth web admin admin123")
        assert "Authenticated" in result.stdout
        assert range_env.state.active_sessions["web"] == "admin"

    def test_auth_failure(self):
        """A wrong password reports failure and opens no session."""
        range_env = self._fresh_env()
        result = self._red(range_env, "auth web admin wrongpass")
        assert "failed" in result.stderr.lower()
        assert "web" not in range_env.state.active_sessions

    def test_auth_wrong_host(self):
        """Valid credentials are rejected on a host the user is not listed for."""
        range_env = self._fresh_env()
        # The snapshot lists guest on web only, so db must refuse the login.
        result = self._red(range_env, "auth db guest guest")
        assert "failed" in result.stderr.lower()
        assert "db" not in range_env.state.active_sessions

    def test_auth_attempt_logged(self):
        """A successful login is recorded once in ``auth_attempts``."""
        range_env = self._fresh_env()
        self._red(range_env, "auth web admin admin123")
        attempts = range_env.state.auth_attempts
        assert len(attempts) == 1
        assert attempts[0]["success"] is True

    def test_auth_failure_logged(self):
        """A failed login is recorded once in ``auth_attempts``."""
        range_env = self._fresh_env()
        self._red(range_env, "auth web admin wrong")
        attempts = range_env.state.auth_attempts
        assert len(attempts) == 1
        assert attempts[0]["success"] is False

    def test_logout_success(self):
        """Logging out of an open session reports it and removes the session."""
        range_env = self._fresh_env()
        self._red(range_env, "auth web admin admin123")
        assert "web" in range_env.state.active_sessions
        result = self._red(range_env, "logout web")
        assert "Logged out" in result.stdout
        assert "web" not in range_env.state.active_sessions

    def test_logout_no_session(self):
        """Logging out without a session produces an error message."""
        range_env = self._fresh_env()
        result = self._red(range_env, "logout web")
        assert "No active session" in result.stderr

    def test_auth_missing_args(self):
        """``auth`` without a password prints usage on stderr."""
        range_env = self._fresh_env()
        result = self._red(range_env, "auth web admin")
        assert "Usage" in result.stderr

    def test_logout_missing_args(self):
        """``logout`` without a host prints usage on stderr."""
        range_env = self._fresh_env()
        result = self._red(range_env, "logout")
        assert "Usage" in result.stderr

    def test_auth_creates_access_grant(self):
        """A successful login adds a ``<host>:shell`` entry to ``access_grants``."""
        range_env = self._fresh_env()
        self._red(range_env, "auth web admin admin123")
        assert "web:shell" in range_env.state.access_grants
378
+
379
+
380
+ # -------------------------------------------------------------------
381
+ # Pivot mechanics (#26)
382
+ # -------------------------------------------------------------------
383
+
384
+
385
class TestPivotMechanics:
    """Access-grant and pivot-history bookkeeping (#26)."""

    @staticmethod
    def _env_for(hosts, users):
        """Return a docker-less environment reset onto a minimal snapshot.

        The snapshot carries only the given *hosts* and *users*; it has no
        flags and an empty golden path.
        """
        range_env = RangeEnvironment(docker_available=False)
        minimal = SnapshotSpec(
            topology={"hosts": hosts, "users": users},
            flags=[],
            golden_path=[],
            task=TaskSpec(red_briefing="Go.", blue_briefing="Watch."),
        )
        range_env.reset(snapshot=minimal)
        return range_env

    def test_pivot_detected_from_credential_leak(self):
        """Leaked credentials in command output update grants and pivot history.

        The output contains a username, password and host that match a user
        in the snapshot topology.
        """
        range_env = self._env_for(
            ["attacker", "web", "db"],
            [{"username": "dbadmin", "password": "s3cret!", "hosts": ["db"]}],
        )

        # Feed _check_pivot output that looks like a leaked config file.
        leak_action = RangeAction(command="cat /etc/app/config.ini", mode="red")
        leaked_output = "db_user = dbadmin\ndb_pass = s3cret!\nhost = db"
        range_env._check_pivot(leak_action, leaked_output)

        assert "db:credential" in range_env.state.access_grants
        history = range_env.state.pivot_history
        assert len(history) == 1
        assert history[0]["to"] == "db"
        assert history[0]["via"] == "credential_reuse"

    def test_no_pivot_without_matching_creds(self):
        """Output without any known credentials leaves grants and history empty."""
        range_env = self._env_for(
            ["attacker", "web"],
            [{"username": "admin", "password": "secret", "hosts": ["web"]}],
        )
        harmless_action = RangeAction(command="ls", mode="red")
        range_env._check_pivot(harmless_action, "no credentials here")
        assert range_env.state.access_grants == []
        assert range_env.state.pivot_history == []

    def test_pivot_not_duplicated(self):
        """Seeing the same leak twice records the credential grant only once."""
        range_env = self._env_for(
            ["attacker", "web", "db"],
            [{"username": "admin", "password": "pass", "hosts": ["db"]}],
        )
        repeated_action = RangeAction(command="cat config", mode="red")
        for _ in range(2):
            range_env._check_pivot(repeated_action, "admin pass db")
        # The grant must appear exactly once despite the repeated leak.
        assert range_env.state.access_grants.count("db:credential") == 1

    def test_state_has_access_grants_field(self):
        """A default RangeState has empty ``access_grants`` and ``pivot_history``."""
        blank = RangeState()
        assert blank.access_grants == []
        assert blank.pivot_history == []

    def test_state_has_auth_fields(self):
        """A default RangeState has empty session, auth-attempt and milestone fields."""
        blank = RangeState()
        assert blank.active_sessions == {}
        assert blank.auth_attempts == []
        assert blank.milestones_completed == []
tests/test_manifest.py CHANGED
@@ -3,7 +3,7 @@
3
  import pytest
4
  from pydantic import ValidationError
5
 
6
- from manifests.schema import Manifest, load_manifest
7
 
8
 
9
  class TestManifestLoading:
@@ -141,3 +141,49 @@ class TestBugFamilies:
141
  def test_difficulty_min_le_max_vulns(self, manifests_dir):
142
  m = load_manifest(manifests_dir / "tier1_basic.yaml")
143
  assert m.difficulty.min_vulns <= m.difficulty.max_vulns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import pytest
4
  from pydantic import ValidationError
5
 
6
+ from manifests.schema import ExposurePolicy, Host, Manifest, load_manifest
7
 
8
 
9
  class TestManifestLoading:
 
141
  def test_difficulty_min_le_max_vulns(self, manifests_dir):
142
  m = load_manifest(manifests_dir / "tier1_basic.yaml")
143
  assert m.difficulty.min_vulns <= m.difficulty.max_vulns
144
+
145
+
146
class TestExposurePolicy:
    """Validation and defaults of ExposurePolicy and ``Host.exposure`` (#18)."""

    def test_default_exposure_policy(self):
        """With no arguments the policy is public, needs no auth and has no notes."""
        policy = ExposurePolicy()
        assert policy.level == "public"
        assert policy.auth_required is False
        assert policy.notes == ""

    def test_custom_exposure_policy(self):
        """Explicitly supplied values are stored unchanged."""
        policy = ExposurePolicy(
            level="hidden",
            auth_required=True,
            notes="Internal only",
        )
        assert policy.level == "hidden"
        assert policy.auth_required is True
        assert policy.notes == "Internal only"

    def test_invalid_level_rejected(self):
        """A level outside the allowed set raises ValidationError."""
        with pytest.raises(ValidationError):
            ExposurePolicy(level="nonexistent")

    def test_all_valid_levels(self):
        """Each of the four allowed levels is accepted and stored as given."""
        allowed = ("public", "hidden", "authenticated", "misconfigured")
        for wanted in allowed:
            policy = ExposurePolicy(level=wanted)
            assert policy.level == wanted

    def test_host_with_exposure_field(self):
        """A Host keeps the ExposurePolicy passed to it."""
        policy = ExposurePolicy(level="authenticated", auth_required=True)
        web_host = Host(name="web", zone="dmz", exposure=policy)
        assert web_host.exposure.level == "authenticated"
        assert web_host.exposure.auth_required is True

    def test_host_default_exposure(self):
        """A Host built without ``exposure`` gets the default public policy."""
        web_host = Host(name="web", zone="dmz")
        assert web_host.exposure.level == "public"
        assert web_host.exposure.auth_required is False

    def test_existing_manifests_still_load_with_exposure(self, manifests_dir):
        """Adding the exposure field must not break existing manifests."""
        manifest = load_manifest(manifests_dir / "tier1_basic.yaml")
        # Every host in the loaded manifest is expected to report the default level.
        for loaded_host in manifest.topology.hosts:
            assert loaded_host.exposure.level == "public"