# File size: 10,075 Bytes | c1a5935
"""Tests for the Security Audit Environment."""
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from server.security_audit_env_environment import SecurityAuditEnvironment
from models import SecurityAuditAction, SecurityAuditObservation
class TestReset:
    """Checks on the observation returned by ``SecurityAuditEnvironment.reset``."""

    def test_clean_state(self):
        """The first reset on "easy" returns an untouched starting observation."""
        environment = SecurityAuditEnvironment()
        initial = environment.reset(scenario_id="easy")

        # One assert per property so a failure names the exact field.
        assert initial.done is False
        assert initial.reward == 0.0
        assert initial.discovered_hosts == []
        assert initial.steps_remaining == 30
        assert "QuickLaunch" in initial.message

    def test_clears_previous(self):
        """Resetting after a network scan wipes discovered hosts and reward."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        scan = SecurityAuditAction(
            action_type="use_tool",
            tool_name="network_scan",
            arguments={"target": "10.0.1.0/24"},
        )
        environment.step(scan)

        fresh = environment.reset(scenario_id="easy")
        assert fresh.discovered_hosts == []
        assert environment._episode_reward == 0.0

    def test_all_scenarios(self):
        """Each scenario id resets to its expected step count and is not done."""
        environment = SecurityAuditEnvironment()
        expected_steps = {"easy": 30, "medium": 50, "hard": 60}

        for scenario, budget in expected_steps.items():
            observation = environment.reset(scenario_id=scenario)
            assert observation.steps_remaining == budget
            assert observation.done is False
class TestActions:
    """Exercises each action type once against a freshly reset "easy" episode."""

    def test_list_tools(self):
        """``list_tools`` returns ten tools and carries zero reward."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        result = environment.step(SecurityAuditAction(action_type="list_tools"))

        assert result.available_tools is not None
        assert len(result.available_tools) == 10
        assert result.reward == 0.0

    def test_network_scan(self):
        """Scanning 10.0.1.0/24 reports two hosts and a positive reward."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        scan = SecurityAuditAction(
            action_type="use_tool",
            tool_name="network_scan",
            arguments={"target": "10.0.1.0/24"},
        )
        result = environment.step(scan)

        assert len(result.discovered_hosts) == 2
        assert result.reward > 0

    def test_missing_tool_name(self):
        """``use_tool`` without a tool name gives an error output and -0.02 reward."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        result = environment.step(SecurityAuditAction(action_type="use_tool"))

        assert "Error" in result.tool_output
        assert result.reward == -0.02

    def test_submit_finding(self):
        """A fully specified finding is counted and rewarded positively."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        finding = {
            "title": "SQL Injection in /api/login",
            "host": "10.0.1.10",
            "type": "SQL Injection",
            "severity": "Critical",
            "cwe": "CWE-89",
        }
        result = environment.step(
            SecurityAuditAction(action_type="submit_finding", arguments=finding)
        )

        assert result.findings_submitted == 1
        assert result.reward > 0

    def test_submit_missing_fields(self):
        """A finding with only a title gets -0.02 and a "Missing" message."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        result = environment.step(
            SecurityAuditAction(action_type="submit_finding", arguments={"title": "Test"})
        )

        assert result.reward == -0.02
        assert "Missing" in result.tool_output

    def test_generate_report(self):
        """``generate_report`` ends the episode and attaches grades metadata."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        result = environment.step(SecurityAuditAction(action_type="generate_report"))

        assert result.done is True
        assert "SECURITY AUDIT REPORT" in result.tool_output
        # Metadata must be truthy before its keys are inspected.
        assert result.metadata
        assert "grades" in result.metadata
class TestRewards:
    """Compares step rewards across actions, scenarios and repeated calls."""

    def test_vary_by_action(self):
        """``list_tools`` earns nothing while a network scan earns a positive reward."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        listing = environment.step(SecurityAuditAction(action_type="list_tools"))
        scan = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.1.0/24"},
            )
        )

        assert listing.reward == 0.0
        assert scan.reward > 0.0

    def test_difficulty_scaling(self):
        """The first network scan on "medium" out-rewards the one on "easy"."""
        scan_reward = {}
        for scenario in ("easy", "medium"):
            environment = SecurityAuditEnvironment()
            environment.reset(scenario_id=scenario)

            # "easy" is scanned on 10.0.1.0/24, "medium" on 10.0.2.0/24.
            third_octet = 1 if scenario == "easy" else 2
            outcome = environment.step(
                SecurityAuditAction(
                    action_type="use_tool",
                    tool_name="network_scan",
                    arguments={"target": f"10.0.{third_octet}.0/24"},
                )
            )
            scan_reward[scenario] = outcome.reward

        assert scan_reward["medium"] > scan_reward["easy"]

    def test_redundant_penalty(self):
        """Repeating the same web crawl yields a lower reward the second time."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        first = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="web_crawl",
                arguments={"host": "10.0.1.10"},
            )
        )
        repeat = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="web_crawl",
                arguments={"host": "10.0.1.10"},
            )
        )

        assert repeat.reward < first.reward
class TestEpisodeBoundaries:
    """Checks the two ways an "easy" episode is seen to end in these tests."""

    def test_step_limit(self):
        """Stepping ``list_tools`` up to 31 times ends with ``done`` set."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        latest = None
        taken = 0
        # At most 31 steps; stop early as soon as the episode reports done.
        while taken < 31:
            latest = environment.step(SecurityAuditAction(action_type="list_tools"))
            taken += 1
            if latest.done:
                break

        assert latest.done is True

    def test_generate_report_ends(self):
        """Generating the report finishes the episode with zero steps remaining."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        final = environment.step(SecurityAuditAction(action_type="generate_report"))

        assert final.done is True
        assert final.steps_remaining == 0
class TestProgressiveDiscovery:
    """Checks that 10.0.2.30 only shows up in "medium" after an SSRF finding."""

    def test_hidden_hosts_initially(self):
        """The first scan lists 10.0.2.10 but not 10.0.2.30."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="medium")

        scan = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.2.0/24"},
            )
        )

        assert "10.0.2.10" in scan.discovered_hosts
        assert "10.0.2.30" not in scan.discovered_hosts

    def test_unlock_after_finding(self):
        """Scan, submit the SSRF finding, rescan: 10.0.2.30 is now listed."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="medium")

        environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.2.0/24"},
            )
        )

        ssrf_finding = {
            "title": "SSRF",
            "host": "10.0.2.10",
            "endpoint": "/api/upload/image",
            "type": "Server-Side Request Forgery (SSRF)",
            "severity": "High",
            "cwe": "CWE-918",
        }
        environment.step(
            SecurityAuditAction(action_type="submit_finding", arguments=ssrf_finding)
        )

        rescan = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.2.0/24"},
            )
        )

        assert "10.0.2.30" in rescan.discovered_hosts
class TestHoneypot:
    """Checks the negative reward for crawling 10.0.3.99 in the "hard" scenario."""

    def test_penalty(self):
        """After a subnet scan, a web crawl of 10.0.3.99 yields a reward below zero."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="hard")

        environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.3.0/24"},
            )
        )
        # NOTE(review): 10.0.3.99 is presumably the honeypot host; confirm against the scenario data.
        crawl = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="web_crawl",
                arguments={"host": "10.0.3.99"},
            )
        )

        assert crawl.reward < 0
class TestTruncation:
    """Checks the ``truncated`` flag for both ways an episode ends."""

    def test_report_not_truncated(self):
        """Ending through ``generate_report`` sets ``done`` but not ``truncated``."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        final = environment.step(SecurityAuditAction(action_type="generate_report"))

        assert final.done is True
        assert final.truncated is False

    def test_step_limit_truncated(self):
        """Ending by repeated ``list_tools`` steps sets ``done`` and ``truncated``."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        latest = None
        # Up to 31 steps, stopping at the first observation flagged done.
        for _attempt in range(31):
            latest = environment.step(SecurityAuditAction(action_type="list_tools"))
            if latest.done:
                break

        assert latest.done is True
        assert latest.truncated is True
class TestPhaseTracking:
    """Checks ``current_phase`` on the observation after each kind of action."""

    def test_recon(self):
        """``list_tools`` on a fresh episode reports the "reconnaissance" phase."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        observation = environment.step(SecurityAuditAction(action_type="list_tools"))

        assert observation.current_phase == "reconnaissance"

    def test_enumeration(self):
        """A network scan reports the "enumeration" phase."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        observation = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.1.0/24"},
            )
        )

        assert observation.current_phase == "enumeration"

    def test_exploitation(self):
        """Submitting a finding after a scan reports the "exploitation" phase."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.1.0/24"},
            )
        )
        observation = environment.step(
            SecurityAuditAction(
                action_type="submit_finding",
                arguments={"title": "T", "host": "10.0.1.10", "severity": "H"},
            )
        )

        assert observation.current_phase == "exploitation"

    def test_reporting(self):
        """``generate_report`` reports the "reporting" phase."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        observation = environment.step(SecurityAuditAction(action_type="generate_report"))

        assert observation.current_phase == "reporting"
class TestSeed:
    """Checks ``reset`` with and without an explicit seed."""

    def test_deterministic(self):
        """Two environments reset with seed 42 produce the same message."""
        first_env = SecurityAuditEnvironment()
        first_obs = first_env.reset(seed=42, scenario_id="easy")

        second_env = SecurityAuditEnvironment()
        second_obs = second_env.reset(seed=42, scenario_id="easy")

        assert first_obs.message == second_obs.message

    def test_no_seed(self):
        """Reset without a seed still yields 30 steps on "easy"."""
        environment = SecurityAuditEnvironment()
        observation = environment.reset(scenario_id="easy")

        assert observation.steps_remaining == 30
class TestFindingRewardCap:
    """Checks how the reward changes across repeated finding submissions."""

    def test_diminishing(self):
        """Six findings in a row: the first earns 0.02, the sixth earns 0.0."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")

        # Submit six distinct-titled findings for the same host, in order,
        # recording the reward of each step.
        earned = [
            environment.step(
                SecurityAuditAction(
                    action_type="submit_finding",
                    arguments={
                        "title": f"Fake {index}",
                        "host": "10.0.1.99",
                        "severity": "Low",
                    },
                )
            ).reward
            for index in range(6)
        ]

        assert earned[0] == 0.02
        assert earned[5] == 0.0
class TestBaseline:
    """End-to-end run of the "easy" scenario with three fully detailed findings."""

    def test_easy_scores_high(self):
        """Scan, submit three findings, generate the report, then check the grades."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")
        environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.1.0/24"},
            )
        )

        # Columns: title, host, endpoint (or None), type, severity, CVSS, CWE, OWASP.
        findings = [
            ("SQL Injection", "10.0.1.10", "/api/login", "SQL Injection", "Critical", 9.8, "CWE-89", "A03:2021 - Injection"),
            ("Broken Auth", "10.0.1.10", "/admin", "Broken Authentication", "High", 8.1, "CWE-798", "A07:2021 - Identification and Authentication Failures"),
            ("Misconfig", "10.0.1.20", None, "Security Misconfiguration", "Critical", 9.1, "CWE-200", "A05:2021 - Security Misconfiguration"),
        ]

        for name, address, endpoint, category, severity, cvss, cwe, owasp in findings:
            payload = {
                "title": name,
                "host": address,
                "type": category,
                "severity": severity,
                "cvss_score": cvss,
                "cwe": cwe,
                "owasp": owasp,
                "evidence": "Detailed evidence for " + name,
                "remediation": "Detailed remediation for " + name,
            }
            # The endpoint key is only sent when the table row provides one.
            if endpoint:
                payload["endpoint"] = endpoint
            environment.step(
                SecurityAuditAction(action_type="submit_finding", arguments=payload)
            )

        report = environment.step(SecurityAuditAction(action_type="generate_report"))
        grades = report.metadata["grades"]

        assert grades["detection_rate"] == 1.0
        assert grades["true_positives"] == 3
        assert grades["final_score"] >= 0.90
|