Spaces:
Sleeping
Sleeping
Sync: compliance mapping, anti-gaming, 55 tests, mandatory stdout format, pivoting+compliance weights
c1a5935 verified | """Tests for the Security Audit Environment.""" | |
| import sys, os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from server.security_audit_env_environment import SecurityAuditEnvironment | |
| from models import SecurityAuditAction, SecurityAuditObservation | |
| class TestReset: | |
| def test_clean_state(self): | |
| env = SecurityAuditEnvironment() | |
| obs = env.reset(scenario_id="easy") | |
| assert obs.done is False and obs.reward == 0.0 and obs.discovered_hosts == [] | |
| assert obs.steps_remaining == 30 and "QuickLaunch" in obs.message | |
| def test_clears_previous(self): | |
| env = SecurityAuditEnvironment() | |
| env.reset(scenario_id="easy") | |
| env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.1.0/24"})) | |
| obs = env.reset(scenario_id="easy") | |
| assert obs.discovered_hosts == [] and env._episode_reward == 0.0 | |
| def test_all_scenarios(self): | |
| env = SecurityAuditEnvironment() | |
| for sid, steps in [("easy", 30), ("medium", 50), ("hard", 60)]: | |
| obs = env.reset(scenario_id=sid) | |
| assert obs.steps_remaining == steps and obs.done is False | |
| class TestActions: | |
| def test_list_tools(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="list_tools")) | |
| assert obs.available_tools is not None and len(obs.available_tools) == 10 and obs.reward == 0.0 | |
| def test_network_scan(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.1.0/24"})) | |
| assert len(obs.discovered_hosts) == 2 and obs.reward > 0 | |
| def test_missing_tool_name(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="use_tool")) | |
| assert "Error" in obs.tool_output and obs.reward == -0.02 | |
| def test_submit_finding(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="submit_finding", arguments={"title": "SQL Injection in /api/login", "host": "10.0.1.10", "type": "SQL Injection", "severity": "Critical", "cwe": "CWE-89"})) | |
| assert obs.findings_submitted == 1 and obs.reward > 0 | |
| def test_submit_missing_fields(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="submit_finding", arguments={"title": "Test"})) | |
| assert obs.reward == -0.02 and "Missing" in obs.tool_output | |
| def test_generate_report(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="generate_report")) | |
| assert obs.done is True and "SECURITY AUDIT REPORT" in obs.tool_output and obs.metadata and "grades" in obs.metadata | |
| class TestRewards: | |
| def test_vary_by_action(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs1 = env.step(SecurityAuditAction(action_type="list_tools")) | |
| obs2 = env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.1.0/24"})) | |
| assert obs1.reward == 0.0 and obs2.reward > 0.0 | |
| def test_difficulty_scaling(self): | |
| rewards = {} | |
| for sid in ["easy", "medium"]: | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id=sid) | |
| obs = env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": f"10.0.{1 if sid=='easy' else 2}.0/24"})) | |
| rewards[sid] = obs.reward | |
| assert rewards["medium"] > rewards["easy"] | |
| def test_redundant_penalty(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs1 = env.step(SecurityAuditAction(action_type="use_tool", tool_name="web_crawl", arguments={"host": "10.0.1.10"})) | |
| obs2 = env.step(SecurityAuditAction(action_type="use_tool", tool_name="web_crawl", arguments={"host": "10.0.1.10"})) | |
| assert obs2.reward < obs1.reward | |
| class TestEpisodeBoundaries: | |
| def test_step_limit(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = None | |
| for _ in range(31): | |
| obs = env.step(SecurityAuditAction(action_type="list_tools")) | |
| if obs.done: break | |
| assert obs.done is True | |
| def test_generate_report_ends(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="generate_report")) | |
| assert obs.done is True and obs.steps_remaining == 0 | |
| class TestProgressiveDiscovery: | |
| def test_hidden_hosts_initially(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="medium") | |
| obs = env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.2.0/24"})) | |
| assert "10.0.2.10" in obs.discovered_hosts and "10.0.2.30" not in obs.discovered_hosts | |
| def test_unlock_after_finding(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="medium") | |
| env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.2.0/24"})) | |
| env.step(SecurityAuditAction(action_type="submit_finding", arguments={"title": "SSRF", "host": "10.0.2.10", "endpoint": "/api/upload/image", "type": "Server-Side Request Forgery (SSRF)", "severity": "High", "cwe": "CWE-918"})) | |
| obs = env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.2.0/24"})) | |
| assert "10.0.2.30" in obs.discovered_hosts | |
| class TestHoneypot: | |
| def test_penalty(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="hard") | |
| env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.3.0/24"})) | |
| obs = env.step(SecurityAuditAction(action_type="use_tool", tool_name="web_crawl", arguments={"host": "10.0.3.99"})) | |
| assert obs.reward < 0 | |
| class TestTruncation: | |
| def test_report_not_truncated(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = env.step(SecurityAuditAction(action_type="generate_report")) | |
| assert obs.done is True and obs.truncated is False | |
| def test_step_limit_truncated(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| obs = None | |
| for _ in range(31): | |
| obs = env.step(SecurityAuditAction(action_type="list_tools")) | |
| if obs.done: break | |
| assert obs.done is True and obs.truncated is True | |
| class TestPhaseTracking: | |
| def test_recon(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| assert env.step(SecurityAuditAction(action_type="list_tools")).current_phase == "reconnaissance" | |
| def test_enumeration(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| assert env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.1.0/24"})).current_phase == "enumeration" | |
| def test_exploitation(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.1.0/24"})) | |
| assert env.step(SecurityAuditAction(action_type="submit_finding", arguments={"title": "T", "host": "10.0.1.10", "severity": "H"})).current_phase == "exploitation" | |
| def test_reporting(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| assert env.step(SecurityAuditAction(action_type="generate_report")).current_phase == "reporting" | |
| class TestSeed: | |
| def test_deterministic(self): | |
| e1 = SecurityAuditEnvironment(); o1 = e1.reset(seed=42, scenario_id="easy") | |
| e2 = SecurityAuditEnvironment(); o2 = e2.reset(seed=42, scenario_id="easy") | |
| assert o1.message == o2.message | |
| def test_no_seed(self): | |
| env = SecurityAuditEnvironment() | |
| assert env.reset(scenario_id="easy").steps_remaining == 30 | |
| class TestFindingRewardCap: | |
| def test_diminishing(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| rewards = [] | |
| for i in range(6): | |
| obs = env.step(SecurityAuditAction(action_type="submit_finding", arguments={"title": f"Fake {i}", "host": "10.0.1.99", "severity": "Low"})) | |
| rewards.append(obs.reward) | |
| assert rewards[0] == 0.02 and rewards[5] == 0.0 | |
| class TestBaseline: | |
| def test_easy_scores_high(self): | |
| env = SecurityAuditEnvironment(); env.reset(scenario_id="easy") | |
| env.step(SecurityAuditAction(action_type="use_tool", tool_name="network_scan", arguments={"target": "10.0.1.0/24"})) | |
| for title, host, ep, typ, sev, cvss, cwe, owasp in [ | |
| ("SQL Injection", "10.0.1.10", "/api/login", "SQL Injection", "Critical", 9.8, "CWE-89", "A03:2021 - Injection"), | |
| ("Broken Auth", "10.0.1.10", "/admin", "Broken Authentication", "High", 8.1, "CWE-798", "A07:2021 - Identification and Authentication Failures"), | |
| ("Misconfig", "10.0.1.20", None, "Security Misconfiguration", "Critical", 9.1, "CWE-200", "A05:2021 - Security Misconfiguration"), | |
| ]: | |
| args = {"title": title, "host": host, "type": typ, "severity": sev, "cvss_score": cvss, "cwe": cwe, "owasp": owasp, "evidence": "Detailed evidence for " + title, "remediation": "Detailed remediation for " + title} | |
| if ep: args["endpoint"] = ep | |
| env.step(SecurityAuditAction(action_type="submit_finding", arguments=args)) | |
| obs = env.step(SecurityAuditAction(action_type="generate_report")) | |
| g = obs.metadata["grades"] | |
| assert g["detection_rate"] == 1.0 and g["true_positives"] == 3 and g["final_score"] >= 0.90 | |