security-audit-env / tests /test_environment.py
anshumanatrey's picture
Sync: compliance mapping, anti-gaming, 55 tests, mandatory stdout format, pivoting+compliance weights
c1a5935 verified
"""Tests for the Security Audit Environment."""
import math
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from server.security_audit_env_environment import SecurityAuditEnvironment
from models import SecurityAuditAction, SecurityAuditObservation
class TestReset:
    """Exercise ``SecurityAuditEnvironment.reset`` for each scenario id."""

    def test_clean_state(self):
        """The first observation of the easy scenario shows an untouched episode."""
        environment = SecurityAuditEnvironment()
        initial = environment.reset(scenario_id="easy")
        assert initial.done is False
        assert initial.reward == 0.0
        assert initial.discovered_hosts == []
        assert initial.steps_remaining == 30
        assert "QuickLaunch" in initial.message

    def test_clears_previous(self):
        """A second reset after a network scan clears hosts and episode reward."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")
        scan = SecurityAuditAction(
            action_type="use_tool",
            tool_name="network_scan",
            arguments={"target": "10.0.1.0/24"},
        )
        environment.step(scan)
        after_reset = environment.reset(scenario_id="easy")
        assert after_reset.discovered_hosts == []
        assert environment._episode_reward == 0.0

    def test_all_scenarios(self):
        """Every scenario starts not-done with its expected step budget."""
        environment = SecurityAuditEnvironment()
        expected_budget = {"easy": 30, "medium": 50, "hard": 60}
        # The same environment instance is reused across scenarios on purpose.
        for scenario, budget in expected_budget.items():
            initial = environment.reset(scenario_id=scenario)
            assert initial.steps_remaining == budget
            assert initial.done is False
class TestActions:
    """Cover each action type passed to ``SecurityAuditEnvironment.step``.

    The ``-0.02`` penalty is compared with ``math.isclose`` because that
    constant has no exact binary representation; an exact ``==`` would make
    the test depend on how the environment happens to compute the value.
    Zero rewards are still compared exactly, since ``0.0`` is representable.
    """

    def test_list_tools(self):
        """``list_tools`` returns ten tools and earns no reward."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        obs = env.step(SecurityAuditAction(action_type="list_tools"))
        assert obs.available_tools is not None
        assert len(obs.available_tools) == 10
        assert obs.reward == 0.0

    def test_network_scan(self):
        """Scanning 10.0.1.0/24 in the easy scenario finds two hosts and pays out."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        obs = env.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.1.0/24"},
            )
        )
        assert len(obs.discovered_hosts) == 2
        assert obs.reward > 0

    def test_missing_tool_name(self):
        """``use_tool`` without a tool name reports an error and a -0.02 penalty."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        obs = env.step(SecurityAuditAction(action_type="use_tool"))
        assert "Error" in obs.tool_output
        assert math.isclose(obs.reward, -0.02, abs_tol=1e-9)

    def test_submit_finding(self):
        """A fully specified finding is counted and rewarded."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        obs = env.step(
            SecurityAuditAction(
                action_type="submit_finding",
                arguments={
                    "title": "SQL Injection in /api/login",
                    "host": "10.0.1.10",
                    "type": "SQL Injection",
                    "severity": "Critical",
                    "cwe": "CWE-89",
                },
            )
        )
        assert obs.findings_submitted == 1
        assert obs.reward > 0

    def test_submit_missing_fields(self):
        """A finding with only a title is penalised and flagged as missing fields."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        obs = env.step(
            SecurityAuditAction(action_type="submit_finding", arguments={"title": "Test"})
        )
        assert math.isclose(obs.reward, -0.02, abs_tol=1e-9)
        assert "Missing" in obs.tool_output

    def test_generate_report(self):
        """``generate_report`` ends the episode and attaches grades to the metadata."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        obs = env.step(SecurityAuditAction(action_type="generate_report"))
        assert obs.done is True
        assert "SECURITY AUDIT REPORT" in obs.tool_output
        # Check truthiness first so a missing metadata fails as an assertion,
        # not as a TypeError from the membership test below.
        assert obs.metadata
        assert "grades" in obs.metadata
class TestRewards:
    """Reward checks: variation by action, scenario scaling, repeat penalty."""

    @staticmethod
    def _started(scenario):
        """Return a new environment that has been reset to ``scenario``."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id=scenario)
        return env

    @staticmethod
    def _crawl_action():
        """Build a fresh ``web_crawl`` action against 10.0.1.10."""
        return SecurityAuditAction(
            action_type="use_tool",
            tool_name="web_crawl",
            arguments={"host": "10.0.1.10"},
        )

    def test_vary_by_action(self):
        """Listing tools earns nothing while a network scan earns a positive reward."""
        env = self._started("easy")
        listing = env.step(SecurityAuditAction(action_type="list_tools"))
        scanning = env.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.1.0/24"},
            )
        )
        assert listing.reward == 0.0
        assert scanning.reward > 0.0

    def test_difficulty_scaling(self):
        """The first network scan pays more in the medium scenario than in easy."""
        reward_by_scenario = {}
        for sid in ("easy", "medium"):
            env = self._started(sid)
            scan = SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": f"10.0.{1 if sid=='easy' else 2}.0/24"},
            )
            reward_by_scenario[sid] = env.step(scan).reward
        assert reward_by_scenario["medium"] > reward_by_scenario["easy"]

    def test_redundant_penalty(self):
        """Repeating the identical crawl yields a lower reward than the first run."""
        env = self._started("easy")
        first = env.step(self._crawl_action())
        repeat = env.step(self._crawl_action())
        assert repeat.reward < first.reward
class TestEpisodeBoundaries:
    """Checks on the two ways an episode ends: step exhaustion and reporting."""

    def test_step_limit(self):
        """Within 31 ``list_tools`` steps the easy scenario must report done."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")
        latest = None
        attempts_left = 31
        while attempts_left > 0:
            latest = environment.step(SecurityAuditAction(action_type="list_tools"))
            attempts_left -= 1
            if latest.done:
                break
        assert latest.done is True

    def test_generate_report_ends(self):
        """Generating the report finishes the episode with zero steps remaining."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")
        report = environment.step(SecurityAuditAction(action_type="generate_report"))
        assert report.done is True
        assert report.steps_remaining == 0
class TestProgressiveDiscovery:
    """Host visibility in the medium scenario before and after a finding."""

    def test_hidden_hosts_initially(self):
        """The first scan lists 10.0.2.10 but not 10.0.2.30."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="medium")
        scan = SecurityAuditAction(
            action_type="use_tool",
            tool_name="network_scan",
            arguments={"target": "10.0.2.0/24"},
        )
        result = environment.step(scan)
        assert "10.0.2.10" in result.discovered_hosts
        assert "10.0.2.30" not in result.discovered_hosts

    def test_unlock_after_finding(self):
        """After an SSRF finding on 10.0.2.10, a rescan lists 10.0.2.30."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="medium")
        environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.2.0/24"},
            )
        )
        ssrf_details = {
            "title": "SSRF",
            "host": "10.0.2.10",
            "endpoint": "/api/upload/image",
            "type": "Server-Side Request Forgery (SSRF)",
            "severity": "High",
            "cwe": "CWE-918",
        }
        environment.step(
            SecurityAuditAction(action_type="submit_finding", arguments=ssrf_details)
        )
        rescan = environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.2.0/24"},
            )
        )
        assert "10.0.2.30" in rescan.discovered_hosts
class TestHoneypot:
    """Negative-reward check for 10.0.3.99 in the hard scenario."""

    def test_penalty(self):
        """Crawling 10.0.3.99 after a network scan yields a negative reward.

        The class name suggests 10.0.3.99 is the honeypot host; that mapping
        lives in the scenario definition, not in this file.
        """
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="hard")
        scan = SecurityAuditAction(
            action_type="use_tool",
            tool_name="network_scan",
            arguments={"target": "10.0.3.0/24"},
        )
        environment.step(scan)
        crawl = SecurityAuditAction(
            action_type="use_tool",
            tool_name="web_crawl",
            arguments={"host": "10.0.3.99"},
        )
        outcome = environment.step(crawl)
        assert outcome.reward < 0
class TestTruncation:
    """The ``truncated`` flag for a reported episode versus an exhausted one."""

    def test_report_not_truncated(self):
        """Ending via ``generate_report`` is done but not truncated."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")
        final = environment.step(SecurityAuditAction(action_type="generate_report"))
        assert final.done is True
        assert final.truncated is False

    def test_step_limit_truncated(self):
        """Ending by repeating ``list_tools`` up to 31 times is done and truncated."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")
        final = None
        attempts_left = 31
        while attempts_left > 0:
            final = environment.step(SecurityAuditAction(action_type="list_tools"))
            attempts_left -= 1
            if final.done:
                break
        assert final.done is True
        assert final.truncated is True
class TestPhaseTracking:
    """The ``current_phase`` reported after each kind of action (easy scenario)."""

    @staticmethod
    def _easy_env():
        """Return a new environment reset to the easy scenario."""
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        return env

    @staticmethod
    def _scan():
        """Build a fresh network scan action for 10.0.1.0/24."""
        return SecurityAuditAction(
            action_type="use_tool",
            tool_name="network_scan",
            arguments={"target": "10.0.1.0/24"},
        )

    def test_recon(self):
        """Listing tools reports the reconnaissance phase."""
        env = self._easy_env()
        obs = env.step(SecurityAuditAction(action_type="list_tools"))
        assert obs.current_phase == "reconnaissance"

    def test_enumeration(self):
        """A network scan reports the enumeration phase."""
        env = self._easy_env()
        obs = env.step(self._scan())
        assert obs.current_phase == "enumeration"

    def test_exploitation(self):
        """Submitting a finding after a scan reports the exploitation phase."""
        env = self._easy_env()
        env.step(self._scan())
        finding = SecurityAuditAction(
            action_type="submit_finding",
            arguments={"title": "T", "host": "10.0.1.10", "severity": "H"},
        )
        obs = env.step(finding)
        assert obs.current_phase == "exploitation"

    def test_reporting(self):
        """Generating the report reports the reporting phase."""
        env = self._easy_env()
        obs = env.step(SecurityAuditAction(action_type="generate_report"))
        assert obs.current_phase == "reporting"
class TestSeed:
    """Behaviour of ``reset`` with and without an explicit seed."""

    def test_deterministic(self):
        """Two environments reset with seed 42 produce the same opening message."""
        first_env = SecurityAuditEnvironment()
        first_obs = first_env.reset(seed=42, scenario_id="easy")
        second_env = SecurityAuditEnvironment()
        second_obs = second_env.reset(seed=42, scenario_id="easy")
        assert first_obs.message == second_obs.message

    def test_no_seed(self):
        """Resetting without a seed still gives the easy scenario's 30 steps."""
        environment = SecurityAuditEnvironment()
        opening = environment.reset(scenario_id="easy")
        assert opening.steps_remaining == 30
class TestFindingRewardCap:
    """Rewards for repeated finding submissions shrink to zero."""

    def test_diminishing(self):
        """Six findings against 10.0.1.99: the first earns 0.02, the sixth nothing.

        ``0.02`` has no exact binary representation, so it is compared with
        ``math.isclose`` rather than ``==``; the test then does not depend on
        how the environment computes that value. The final reward is still
        compared exactly, since ``0.0`` is representable.
        """
        env = SecurityAuditEnvironment()
        env.reset(scenario_id="easy")
        rewards = []
        for i in range(6):
            finding = SecurityAuditAction(
                action_type="submit_finding",
                arguments={"title": f"Fake {i}", "host": "10.0.1.99", "severity": "Low"},
            )
            rewards.append(env.step(finding).reward)
        assert math.isclose(rewards[0], 0.02, abs_tol=1e-9)
        assert rewards[5] == 0.0
class TestBaseline:
    """End-to-end run of the easy scenario with three detailed findings."""

    # One row per finding:
    # (title, host, endpoint, type, severity, cvss_score, cwe, owasp)
    _EASY_FINDINGS = (
        ("SQL Injection", "10.0.1.10", "/api/login", "SQL Injection", "Critical", 9.8, "CWE-89", "A03:2021 - Injection"),
        ("Broken Auth", "10.0.1.10", "/admin", "Broken Authentication", "High", 8.1, "CWE-798", "A07:2021 - Identification and Authentication Failures"),
        ("Misconfig", "10.0.1.20", None, "Security Misconfiguration", "Critical", 9.1, "CWE-200", "A05:2021 - Security Misconfiguration"),
    )

    def test_easy_scores_high(self):
        """Scan, submit the three findings, then expect full detection and >= 0.90."""
        environment = SecurityAuditEnvironment()
        environment.reset(scenario_id="easy")
        environment.step(
            SecurityAuditAction(
                action_type="use_tool",
                tool_name="network_scan",
                arguments={"target": "10.0.1.0/24"},
            )
        )
        for row in self._EASY_FINDINGS:
            title, host, endpoint, vuln_type, severity, cvss, cwe, owasp = row
            payload = {
                "title": title,
                "host": host,
                "type": vuln_type,
                "severity": severity,
                "cvss_score": cvss,
                "cwe": cwe,
                "owasp": owasp,
                "evidence": "Detailed evidence for " + title,
                "remediation": "Detailed remediation for " + title,
            }
            # The endpoint key is only sent when the row actually has one.
            if endpoint:
                payload["endpoint"] = endpoint
            environment.step(
                SecurityAuditAction(action_type="submit_finding", arguments=payload)
            )
        report = environment.step(SecurityAuditAction(action_type="generate_report"))
        grades = report.metadata["grades"]
        assert grades["detection_rate"] == 1.0
        assert grades["true_positives"] == 3
        assert grades["final_score"] >= 0.90