| """Integration tests — Task 10.""" |
|
|
| import os |
| import sys |
| import subprocess |
|
|
| import pytest |
|
|
# Put the directory above this file's directory on sys.path so the
# project-local imports that follow resolve regardless of the directory
# pytest is launched from.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
|
| from server.play_environment import CyberSOCEnvironment |
| from server.episode_sandbox import EpisodeTimeout |
| from server.graders import grade_episode |
| from models import SOCActionWrapper, RedActionWrapper |
|
|
|
|
| |
| |
| |
|
|
def _first_host(env, default="WS-042"):
    """Return the first hostname in ``env``'s threat graph.

    Args:
        env: Environment exposing ``_threat_graph.hosts``, an iterable of
            hostnames (iterating it yields the names).
        default: Value returned when the graph has no hosts. Defaults to
            ``"WS-042"``, the fallback this helper has always used.

    Returns:
        The first hostname yielded by ``env._threat_graph.hosts``, or
        ``default`` when that iterable is empty.
    """
    return next(iter(env._threat_graph.hosts), default)
|
|
|
|
def _first_alert(env):
    """Return the first alert yielded by ``env._threat_graph.alerts``.

    Returns ``None`` when the threat graph holds no alerts.
    """
    for alert_id in env._threat_graph.alerts:
        # Only the first item is wanted; stop as soon as one is seen.
        return alert_id
    return None
|
|
|
|
def _valid_action(action_type, **kwargs):
    """Build a ``SOCActionWrapper`` for ``action_type``.

    ``action_type`` is passed as the wrapper's ``type`` field; every other
    keyword argument is forwarded to the wrapper unchanged.
    """
    action = SOCActionWrapper(type=action_type, **kwargs)
    return action
|
|
|
|
| |
| |
| |
|
|
def test_easy_episode_completes():
    """Smoke-test the "easy" task: reset, then five ``query_host`` steps.

    The observation returned by ``reset`` and by every ``step`` must be
    non-``None``; a failure names the call that returned nothing.
    """
    env = CyberSOCEnvironment()
    obs = env.reset(task_id="easy")
    assert obs is not None, "reset() returned no observation"

    hostname = _first_host(env)
    for step_no in range(1, 6):
        obs = env.step(_valid_action("query_host", hostname=hostname))
        # Check inside the loop so the failing step is identified.
        assert obs is not None, f"step {step_no} returned no observation"
|
|
|
|
def test_medium_episode_with_all_10_actions():
    """Drive a "medium" episode through each SOC action type in turn.

    Actions issued, in order: ``correlate_alerts`` (only when the graph has
    at least two alerts), ``query_host``, ``run_forensics``, ``enrich_ioc``
    (only when the graph has an IOC), ``scan_host_vulnerabilities``,
    ``block_ioc`` (same IOC condition), ``kill_process``,
    ``isolate_segment``, ``trigger_playbook`` (best-effort) and finally
    ``submit_containment_plan``.

    Every step except the best-effort playbook must return a non-``None``
    observation.
    """
    import logging  # local import: only this test needs it

    env = CyberSOCEnvironment()
    obs = env.reset(task_id="medium")
    assert obs is not None, "reset() returned no observation"

    hostname = _first_host(env)
    ioc_value = next(iter(env._threat_graph.iocs), None)

    def step(action_type, **kwargs):
        """Execute one SOC action and require a non-None observation."""
        result = env.step(_valid_action(action_type, **kwargs))
        assert result is not None, f"{action_type} returned no observation"
        return result

    # Correlation needs a pair of alerts; skipped when fewer exist.
    alerts = list(env._threat_graph.alerts)
    if len(alerts) >= 2:
        step("correlate_alerts", alert_ids=alerts[:2])

    # Host investigation.
    step("query_host", hostname=hostname)
    step("run_forensics", hostname=hostname)

    # IOC enrichment and vulnerability scan.
    # NOTE(review): ioc_type is always sent as "ip" regardless of the IOC's
    # real type - confirm this is acceptable for the generated scenarios.
    if ioc_value:
        step("enrich_ioc", ioc_value=ioc_value, ioc_type="ip")
    step("scan_host_vulnerabilities", hostname=hostname)

    # Remediation actions.
    if ioc_value:
        step("block_ioc", ioc_value=ioc_value, ioc_type="ip")
    step("kill_process", hostname=hostname, process_name="nonexistent.exe")
    step("isolate_segment", subnet="corporate", reason="test")

    # The playbook call stays best-effort (a failure here must not fail the
    # test), but the failure is logged instead of vanishing silently.
    try:
        env.step(_valid_action(
            "trigger_playbook", playbook_name="c2_disruption", target=hostname,
        ))
    except Exception as exc:
        logging.getLogger(__name__).warning(
            "trigger_playbook raised and was ignored: %r", exc
        )

    # Terminal submission.
    obs = step(
        "submit_containment_plan",
        plan=[{"threat_id": "T1", "actions_taken": ["kill"], "root_cause": "malware", "confidence": 0.8}],
        executive_summary="Contained ransomware.",
    )
    assert obs is not None
|
|
|
|
def test_phase_violation_returns_error():
    """Issuing ``kill_process`` straight after reset must not crash.

    The action is sent as the very first step of an "easy" episode and the
    environment is expected to hand back an observation rather than raise.

    NOTE(review): the original docstring says this should "record negative
    reward"; no reward value is asserted here - confirm where the penalty is
    exposed and add a check.
    """
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")

    premature_kill = _valid_action(
        "kill_process",
        hostname=_first_host(env),
        process_name="fake.exe",
    )
    observation = env.step(premature_kill)

    assert observation is not None
|
|
|
|
def test_lateral_pivot_red_action():
    """A ``lateral_pivot`` red action compromises the target host.

    After the pivot the threat graph must contain at least one
    ``pivoted_from`` edge, the target host's status must be
    ``"compromised"`` and the alert queue must have grown.
    Skipped when the "hard" scenario has no usable source/target pair.
    """
    env = CyberSOCEnvironment(fsp_mode=True)
    env.reset(task_id="hard")
    hosts = env._host_index

    # Source: first host already marked compromised.
    infected = [
        name for name, info in hosts.items()
        if info.get("status") == "compromised"
    ]
    src = infected[0] if infected else None

    # Target: first other host that is neither compromised nor isolated.
    reachable = [
        name for name, info in hosts.items()
        if name != src and info.get("status") not in ("compromised", "isolated")
    ]
    dst = reachable[0] if reachable else None

    if dst is None or src is None:
        pytest.skip("No suitable host pair for lateral pivot test")

    # One blue step first; the test then expects the turn to be red's.
    env.step(_valid_action("query_host", hostname=src))
    assert env._state.active_turn == "red"

    queue_size_before = len(env._alert_queue)
    pivot = RedActionWrapper(type="lateral_pivot", source_host=src, target_host=dst)
    env.step(pivot)

    pivot_edge_count = sum(
        1 for edge in env._threat_graph.edges if edge.edge_type == "pivoted_from"
    )
    assert pivot_edge_count >= 1
    assert env._host_index[dst]["status"] == "compromised"
    assert len(env._alert_queue) > queue_size_before
|
|
|
|
def test_step_reward_accumulates():
    """Running forensics on up to three hosts must raise the step-reward total.

    Only graph hosts that are also present in ``env._host_index`` are
    investigated. The test fails with an explicit message when none qualify,
    instead of reporting a misleading "reward did not grow".
    """
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")

    before = env._step_reward_total

    investigated = 0
    # Copy the host names up front so stepping the environment cannot
    # disturb the iteration.
    for h in list(env._threat_graph.hosts)[:3]:
        if h in env._host_index:
            env.step(_valid_action("run_forensics", hostname=h))
            investigated += 1

    assert investigated, "no threat-graph host is present in the host index"
    assert env._step_reward_total > before, (
        f"step reward did not grow after {investigated} run_forensics step(s)"
    )
|
|
|
|
def test_step_reward_idempotent():
    """Repeating ``run_forensics`` on one host must not pay its reward twice.

    Checks two things:

    * after both calls, the ``("investigation", "run_forensics", hostname)``
      key is recorded in ``env._fired_step_rewards``;
    * if the second call fired no new reward key, ``env._step_reward_total``
      did not increase between the first and second call.

    The second check is conditional so that a different reward legitimately
    firing on the repeat call does not produce a false failure.
    """
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    hostname = _first_host(env)

    env.step(_valid_action("run_forensics", hostname=hostname))
    after_first = env._step_reward_total
    fired_after_first = set(env._fired_step_rewards)

    env.step(_valid_action("run_forensics", hostname=hostname))
    after_second = env._step_reward_total
    newly_fired = set(env._fired_step_rewards) - fired_after_first

    key = ("investigation", "run_forensics", hostname)
    assert key in env._fired_step_rewards

    step_reward_delta = after_second - after_first
    if not newly_fired:
        # Nothing new fired, so any increase would be a double award.
        assert step_reward_delta <= 0, (
            f"repeated run_forensics added {step_reward_delta} step reward "
            "although no new reward key fired"
        )
|
|
|
|
def test_grader_returns_10_dim():
    """``grade_episode`` must return a ``"breakdown"`` with ten entries.

    The episode graded is an "easy" run with a single ``query_host`` step
    and no final containment plan.
    """
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    env.step(_valid_action("query_host", hostname=_first_host(env)))

    grading_inputs = {
        "episode_actions": list(env._state.timeline),
        "final_plan": None,
        "graph": env._threat_graph,
        "task_def": env._task_def,
        "state": env._state,
    }
    breakdown = grade_episode(**grading_inputs)["breakdown"]

    assert len(breakdown) == 10
|
|
|
|
def test_sandbox_step_limit():
    """``check_step_limit`` raises ``EpisodeTimeout`` at the step cap.

    The sandbox wraps a freshly reset "easy" environment and is asked to
    check a step count equal to ``MAX_STEPS_PER_EPISODE``.
    """
    from server.episode_sandbox import EpisodeSandbox, MAX_STEPS_PER_EPISODE

    wrapped_env = CyberSOCEnvironment()
    wrapped_env.reset(task_id="easy")
    sandbox = EpisodeSandbox(wrapped_env)

    with pytest.raises(EpisodeTimeout):
        sandbox.check_step_limit(MAX_STEPS_PER_EPISODE)
|
|
|
|
def test_openenv_validate_compatible():
    """Smoke-test that the environment imports and resets in the project venv.

    Despite its name, this test does not invoke an ``openenv validate``
    command. It runs a one-line script with the interpreter of the ``venv``
    directory that sits beside the project root. The script builds
    ``CyberSOCEnvironment(adaptive=False)``, resets the "easy" task and
    prints ``OK`` followed by the alert-queue length.

    Both virtualenv layouts are probed (``Scripts/python.exe`` on Windows,
    ``bin/python`` elsewhere). The test is skipped when neither exists.
    """
    venv_dir = os.path.abspath(os.path.join(_PROJECT_ROOT, os.pardir, "venv"))
    candidates = (
        os.path.join(venv_dir, "Scripts", "python.exe"),  # Windows layout
        os.path.join(venv_dir, "bin", "python"),  # POSIX layout
    )
    python = next((path for path in candidates if os.path.exists(path)), None)
    if python is None:
        pytest.skip("venv python not found")

    result = subprocess.run(
        [python, "-c",
         "from server.play_environment import CyberSOCEnvironment; "
         "env = CyberSOCEnvironment(adaptive=False); "
         "obs = env.reset(task_id='easy'); "
         "print('OK', len(obs.alert_queue))"],
        cwd=_PROJECT_ROOT,  # so the `server` package is importable
        capture_output=True,
        text=True,
        timeout=60,
    )
    assert result.returncode == 0, f"stderr: {result.stderr}"
    assert "OK" in result.stdout
|
|