File size: 7,120 Bytes
57e71f8 5719ec3 57e71f8 5719ec3 57e71f8 5719ec3 57e71f8 5719ec3 57e71f8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 | """Integration tests — Task 10."""
import os
import sys
import subprocess
import pytest
# Project root = the parent of the directory holding this test file.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put the project root first on sys.path (once) so `server` and `models`
# import without the project being installed.
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
from server.play_environment import CyberSOCEnvironment
from server.episode_sandbox import EpisodeTimeout
from server.graders import grade_episode
from models import SOCActionWrapper, RedActionWrapper
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _first_host(env):
    """Return the first hostname in ``env``'s threat graph.

    Falls back to the literal ``"WS-042"`` when the graph has no hosts.
    """
    for host in env._threat_graph.hosts:
        return host
    return "WS-042"
def _first_alert(env):
    """Return the first alert ID in ``env``'s threat graph, or ``None`` if it has none."""
    alert_ids = iter(env._threat_graph.alerts)
    try:
        return next(alert_ids)
    except StopIteration:
        return None
def _valid_action(action_type, **fields):
    """Build a blue-team (SOC) action.

    ``action_type`` becomes the wrapper's ``type``; every other keyword
    (``hostname``, ``ioc_value``, ...) is forwarded to ``SOCActionWrapper`` as-is.
    """
    return SOCActionWrapper(type=action_type, **fields)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_easy_episode_completes():
    """Five query_host steps on the "easy" task run without raising and yield an observation."""
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    host = _first_host(env)
    last_obs = None
    for _turn in range(5):
        # A fresh action object per step, exactly as the env would receive from an agent.
        last_obs = env.step(_valid_action("query_host", hostname=host))
    assert last_obs is not None
def test_medium_episode_with_all_10_actions():
    """Drive the "medium" task through the 10 blue action types, phase by phase.

    Steps that need scenario data (two alerts to correlate, an IOC to enrich
    and block) are skipped when that data is absent, so fewer than 10 actions
    may run. Every step that does run must return an observation without
    raising.
    """
    env = CyberSOCEnvironment()
    env.reset(task_id="medium")
    hostname = _first_host(env)
    ioc_value = next(iter(env._threat_graph.iocs), None)

    def _step(action_type, **fields):
        # Dispatch one blue action and require an observation back.
        obs = env.step(_valid_action(action_type, **fields))
        assert obs is not None, f"{action_type} returned no observation"
        return obs

    # Triage phase
    alerts = list(env._threat_graph.alerts.keys())
    if len(alerts) >= 2:
        _step("correlate_alerts", alert_ids=alerts[:2])

    # Investigation phase
    _step("query_host", hostname=hostname)
    _step("run_forensics", hostname=hostname)
    if ioc_value:
        # NOTE(review): ioc_type is hard-coded to "ip" — confirm the first IOC is an IP.
        _step("enrich_ioc", ioc_value=ioc_value, ioc_type="ip")
    _step("scan_host_vulnerabilities", hostname=hostname)

    # Remediation phase
    if ioc_value:
        _step("block_ioc", ioc_value=ioc_value, ioc_type="ip")
    _step("kill_process", hostname=hostname, process_name="nonexistent.exe")
    _step("isolate_segment", subnet="corporate", reason="test")

    # Playbook prerequisites may or may not be met; the env must handle either
    # case without raising. No exception handler here: swallowing errors would
    # hide exactly the crash this step is meant to catch.
    _step("trigger_playbook", playbook_name="c2_disruption", target=hostname)

    # Submit plan
    _step(
        "submit_containment_plan",
        plan=[{"threat_id": "T1", "actions_taken": ["kill"], "root_cause": "malware", "confidence": 0.8}],
        executive_summary="Contained ransomware.",
    )
def test_phase_violation_returns_error():
    """An out-of-phase kill_process (issued during triage) must not raise.

    Only the absence of an exception and a non-None observation are checked
    here; the negative reward the action is expected to record is not asserted.
    """
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    premature_kill = _valid_action(
        "kill_process", hostname=_first_host(env), process_name="fake.exe"
    )
    assert env.step(premature_kill) is not None
def test_lateral_pivot_red_action():
    """A red ``lateral_pivot`` adds a pivoted_from edge, compromises the target and raises a SIEM alert."""
    env = CyberSOCEnvironment(fsp_mode=True)
    env.reset(task_id="hard")

    def _pivot_edge_count():
        # Number of "pivoted_from" edges currently in the threat graph.
        return sum(1 for e in env._threat_graph.edges if e.edge_type == "pivoted_from")

    # Pivot from a compromised host to one that is neither compromised nor isolated.
    src = next(
        (h for h, hd in env._host_index.items() if hd.get("status") == "compromised"),
        None,
    )
    dst = next(
        (h for h, hd in env._host_index.items()
         if hd.get("status") not in ("compromised", "isolated") and h != src),
        None,
    )
    if src is None or dst is None:
        pytest.skip("No suitable host pair for lateral pivot test")

    # Blue takes a PassTurn-equivalent (query) so active_turn flips to red.
    env.step(_valid_action("query_host", hostname=src))
    assert env._state.active_turn == "red"

    alerts_before = len(env._alert_queue)
    # Snapshot the edge count right before the pivot so pivot edges that
    # already exist (e.g. seeded by the scenario) cannot satisfy the check.
    edges_before = _pivot_edge_count()
    env.step(RedActionWrapper(type="lateral_pivot", source_host=src, target_host=dst))

    assert _pivot_edge_count() > edges_before
    assert env._host_index[dst]["status"] == "compromised"
    assert len(env._alert_queue) > alerts_before  # SIEM alert generated
def test_step_reward_accumulates():
    """Running forensics on up to three known hosts raises the running step-reward total."""
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    before = env._step_reward_total

    # run_forensics in the investigation phase earns +0.10.
    forensics_runs = 0
    for host in list(env._threat_graph.hosts)[:3]:
        # Only hosts the env actually indexes are valid forensics targets.
        if host in env._host_index:
            env.step(_valid_action("run_forensics", hostname=host))
            forensics_runs += 1

    # Without this guard an empty host set would fail below with a misleading message.
    assert forensics_runs > 0, "no threat-graph host is present in the host index"
    assert env._step_reward_total > before
def test_step_reward_idempotent():
    """A repeated run_forensics on the same host earns no additional step reward."""
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    hostname = _first_host(env)

    env.step(_valid_action("run_forensics", hostname=hostname))
    after_first = env._step_reward_total
    env.step(_valid_action("run_forensics", hostname=hostname))
    after_second = env._step_reward_total

    # The (phase, action, target) triple is recorded once its step reward fires...
    key = ("investigation", "run_forensics", hostname)
    assert key in env._fired_step_rewards
    # ...so the repeat must not add step reward. The total may stay flat or
    # drop (the handler may apply a -0.02 penalty), but never increase.
    step_reward_delta = after_second - after_first
    assert step_reward_delta <= 0, f"repeat run_forensics added {step_reward_delta} reward"
def test_grader_returns_10_dim():
    """grade_episode on a one-step "easy" episode (no final plan) reports a 10-entry breakdown."""
    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    env.step(_valid_action("query_host", hostname=_first_host(env)))

    grading_inputs = dict(
        episode_actions=list(env._state.timeline),
        final_plan=None,
        graph=env._threat_graph,
        task_def=env._task_def,
        state=env._state,
    )
    breakdown = grade_episode(**grading_inputs)["breakdown"]
    assert len(breakdown) == 10
def test_sandbox_step_limit():
    """EpisodeSandbox.check_step_limit raises EpisodeTimeout when given MAX_STEPS_PER_EPISODE."""
    from server.episode_sandbox import MAX_STEPS_PER_EPISODE, EpisodeSandbox

    env = CyberSOCEnvironment()
    env.reset(task_id="easy")
    sandbox = EpisodeSandbox(env)
    with pytest.raises(EpisodeTimeout):
        sandbox.check_step_limit(MAX_STEPS_PER_EPISODE)
def test_openenv_validate_compatible():
    """Smoke-test the environment in a fresh interpreter from the sibling ``venv``.

    Used as a proxy for ``openenv validate`` compatibility. The child process
    constructs ``CyberSOCEnvironment(adaptive=False)`` and resets the "easy"
    task; it must exit 0 and print ``OK``. Skipped when no venv interpreter is
    found next to the project root.
    """
    venv_dir = os.path.abspath(os.path.join(_PROJECT_ROOT, "..", "venv"))
    # Windows venvs keep the interpreter in Scripts\, POSIX venvs in bin/.
    candidates = (
        os.path.join(venv_dir, "Scripts", "python.exe"),
        os.path.join(venv_dir, "bin", "python"),
    )
    python = next((p for p in candidates if os.path.exists(p)), None)
    if python is None:
        pytest.skip("venv python not found")
    result = subprocess.run(
        [python, "-c",
         "from server.play_environment import CyberSOCEnvironment; "
         "env = CyberSOCEnvironment(adaptive=False); "
         "obs = env.reset(task_id='easy'); "
         "print('OK', len(obs.alert_queue))"],
        cwd=_PROJECT_ROOT,
        capture_output=True,
        text=True,
        timeout=60,
    )
    assert result.returncode == 0, f"stderr: {result.stderr}"
    assert "OK" in result.stdout
|