"""Tests for curriculum feedback wiring (#34). Verifies that CurriculumTracker.update_from_result() works correctly and that run_episode() feeds results into the tracker. """ import pytest from open_range.training.curriculum import CurriculumTracker class TestUpdateFromResult: """CurriculumTracker.update_from_result() parses episode results.""" def test_basic_update(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "snap-001", "vuln_classes": ["sqli", "xss"], "red_solved": True, "blue_detected": False, "tier": 1, }) assert len(tracker.episode_history) == 1 assert tracker.vuln_stats["sqli"]["attempts"] == 1 assert tracker.vuln_stats["sqli"]["red_solves"] == 1 assert tracker.vuln_stats["xss"]["blue_detects"] == 0 def test_infer_red_solved_from_outcome(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "snap-002", "vuln_classes": ["sqli"], "outcome": "red_win", "tier": 1, }) assert tracker.episode_history[-1]["red_solved"] is True def test_infer_red_solved_from_flags(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "snap-003", "vuln_classes": ["idor"], "flags_found": ["FLAG{gotcha}"], "tier": 2, }) assert tracker.episode_history[-1]["red_solved"] is True def test_infer_blue_detected_from_outcome(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "snap-004", "vuln_classes": ["xss"], "outcome": "blue_win", "tier": 1, }) assert tracker.episode_history[-1]["blue_detected"] is True def test_timeout_outcome(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "snap-005", "vuln_classes": ["ssrf"], "outcome": "timeout", "tier": 1, }) ep = tracker.episode_history[-1] assert ep["red_solved"] is False assert ep["blue_detected"] is False def test_explicit_flags_override_inference(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "snap-006", "vuln_classes": ["sqli"], "red_solved": False, "blue_detected": True, "outcome": "red_win", # Would infer True, but explicit False wins "tier": 1, }) ep = tracker.episode_history[-1] assert ep["red_solved"] is False assert ep["blue_detected"] is True def test_extra_metadata_passed_through(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "snap-007", "vuln_classes": ["weak_creds"], "red_solved": True, "blue_detected": False, "tier": 1, "steps": 42, "outcome": "red_win", "red_model": "gpt-4", "blue_model": "llama-3", }) ep = tracker.episode_history[-1] assert ep.get("steps") == 42 assert ep.get("outcome") == "red_win" def test_empty_result_defaults(self): tracker = CurriculumTracker() tracker.update_from_result({}) assert len(tracker.episode_history) == 1 ep = tracker.episode_history[-1] assert ep["red_solved"] is False assert ep["blue_detected"] is False assert ep["tier"] == 1 class TestCurriculumStatsUpdate: """Verify that update_from_result correctly updates aggregate stats.""" def test_vuln_stats_accumulate(self): tracker = CurriculumTracker() for i in range(5): tracker.update_from_result({ "snapshot_id": f"snap-{i}", "vuln_classes": ["sqli"], "red_solved": i % 2 == 0, # solved on 0, 2, 4 "blue_detected": i % 3 == 0, # detected on 0, 3 "tier": 1, }) assert tracker.vuln_stats["sqli"]["attempts"] == 5 assert tracker.vuln_stats["sqli"]["red_solves"] == 3 assert tracker.vuln_stats["sqli"]["blue_detects"] == 2 def test_tier_stats_accumulate(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "a", "vuln_classes": ["sqli"], "red_solved": True, "blue_detected": False, "tier": 2, }) tracker.update_from_result({ "snapshot_id": "b", "vuln_classes": ["xss"], "red_solved": False, "blue_detected": True, "tier": 2, }) assert tracker.tier_stats[2]["episodes"] == 2 assert tracker.tier_stats[2]["red_solves"] == 1 assert tracker.tier_stats[2]["blue_detects"] == 1 def test_build_context_after_updates(self): tracker = CurriculumTracker() for i in range(3): tracker.update_from_result({ "snapshot_id": f"s{i}", "vuln_classes": ["sqli"], "red_solved": True, "blue_detected": False, "tier": 1, }) ctx = tracker.get_build_context() assert ctx["episode_count"] == 3 assert ctx["red_solve_rate"] == 1.0 assert ctx["blue_detect_rate"] == 0.0 assert "sqli" in ctx["previous_vuln_classes"] def test_attack_surfaces_passed(self): tracker = CurriculumTracker() tracker.update_from_result({ "snapshot_id": "s1", "vuln_classes": ["sqli"], "red_solved": True, "blue_detected": False, "tier": 1, "attack_surfaces": ["/search?q="], }) ctx = tracker.get_build_context() assert "/search?q=" in ctx["recent_attack_surfaces"] class TestRunEpisodeCurriculumWiring: """run_episode() calls curriculum.update_from_result() when provided.""" def test_run_episode_updates_curriculum(self): from open_range.protocols import ( FlagSpec, SnapshotSpec, TaskSpec, TruthGraph, Vulnerability, ) from open_range.server.environment import RangeEnvironment from open_range.agents.episode import run_episode class ScriptedAgent: """Minimal agent that runs a fixed script.""" def __init__(self, commands): self._commands = list(commands) self._idx = 0 def reset(self, briefing, role): self._idx = 0 def act(self, observation): if self._idx < len(self._commands): cmd = self._commands[self._idx] self._idx += 1 return cmd return "noop" env = RangeEnvironment(docker_available=False, max_steps=4) snapshot = SnapshotSpec( topology={ "hosts": ["attacker", "web"], "tier": 1, }, flags=[FlagSpec(id="f1", value="FLAG{x}", path="/f.txt", host="web")], golden_path=[], truth_graph=TruthGraph( vulns=[Vulnerability(id="v1", type="sqli", host="web")], ), task=TaskSpec(red_briefing="Go.", blue_briefing="Watch."), ) env.reset(snapshot=snapshot) # Patch _select_snapshot to always return our snapshot env._select_snapshot = lambda **kw: snapshot red = ScriptedAgent(["submit_flag FLAG{x}", "noop"]) blue = ScriptedAgent(["submit_finding attack found", "noop"]) tracker = CurriculumTracker() result = run_episode(env, red, blue, max_steps=4, curriculum=tracker) assert len(tracker.episode_history) == 1 ep = tracker.episode_history[0] assert ep["red_solved"] is True # flag was captured -> red_win assert "sqli" in ep["vuln_classes"] def test_run_episode_without_curriculum(self): """run_episode still works when no curriculum is provided.""" from open_range.protocols import SnapshotSpec, TaskSpec from open_range.server.environment import RangeEnvironment from open_range.agents.episode import run_episode class NoopAgent: def reset(self, briefing, role): pass def act(self, observation): return "noop" env = RangeEnvironment(docker_available=False, max_steps=2) snapshot = SnapshotSpec( topology={"hosts": ["attacker", "siem"]}, flags=[], golden_path=[], task=TaskSpec(red_briefing="Test.", blue_briefing="Test."), ) env._select_snapshot = lambda **kw: snapshot result = run_episode(env, NoopAgent(), NoopAgent(), max_steps=2) assert result.outcome in ("red_win", "blue_win", "timeout")