Spaces:
Sleeping
Sleeping
"""Tests for Task 9 — 4 New Action Handlers + Enhanced Existing Handlers."""
import os
import sys

# The project root is the parent of the directory containing this test file.
# Prepend it to sys.path (only once) so that the ``server`` and ``models``
# packages import without the project being installed.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from server.play_environment import CyberSOCEnvironment  # noqa: E402
from server.threat_graph import HostNode, ProcessNode, IOCNode, AlertNode, Edge  # noqa: E402
from models import CorrelateAlerts, EnrichIOC, ScanHostVulnerabilities, TriggerPlaybook  # noqa: E402
def _env_with_graph():
    """Create a CyberSOCEnvironment and reset it onto the "easy" task.

    Returns the environment. Callers rely on the reset having populated
    ``env._threat_graph`` from the task definition.
    """
    fresh_env = CyberSOCEnvironment()
    fresh_env.reset(task_id="easy")
    return fresh_env
def _add_alerts(env, alert_ids):
    """Ensure every id in *alert_ids* has an AlertNode in the threat graph.

    Ids that are already present are left untouched. Each missing id gets a
    new high-severity alert (priority 5.0) whose source host is ``WS-001``.
    """
    graph = env._threat_graph
    for alert_id in alert_ids:
        if alert_id in graph.alerts:
            continue  # keep the existing node as-is
        graph.add_alert(
            AlertNode(
                alert_id=alert_id,
                severity="high",
                priority_score=5.0,
                source_host="WS-001",
            )
        )
def _add_ioc(env, value="1.2.3.4"):
    """Add an ``ip``-type IOCNode (confidence 0.9) for *value* unless one exists."""
    graph = env._threat_graph
    if value in graph.iocs:
        return
    graph.add_ioc(IOCNode(ioc_value=value, ioc_type="ip", confidence=0.9))
def _add_host(env, hostname="WS-FAKE"):
    """Add a compromised, medium-criticality HostNode in the corporate subnet.

    Does nothing when *hostname* is already in the threat graph.
    """
    graph = env._threat_graph
    if hostname in graph.hosts:
        return
    node = HostNode(
        hostname=hostname,
        subnet="corporate",
        business_criticality="medium",
        status="compromised",
    )
    graph.add_host(node)
def test_correlate_alerts_returns_correlation_results():
    """Correlating two alerts known to the graph returns a scored result payload."""
    env = _env_with_graph()
    _add_alerts(env, ["A1", "A2"])
    request = CorrelateAlerts(alert_ids=["A1", "A2"])
    outcome = env._handle_correlate_alerts(request)
    assert "correlation_results" in outcome
    details = outcome["correlation_results"]
    assert "correlation_score" in details
def test_correlate_alerts_error_on_single_id():
    """Correlation with only one resolvable alert id is rejected.

    Two ids are requested, but only "A1" is seeded; "MISSING" is absent, so
    fewer than two alerts are found in the graph and an error is expected.
    """
    env = _env_with_graph()
    _add_alerts(env, ["A1"])
    request = CorrelateAlerts(alert_ids=["A1", "MISSING"])
    assert "error" in env._handle_correlate_alerts(request)
def test_enrich_ioc_updates_graph():
    """Enriching an IOC that is in the graph sets its ``enriched`` flag."""
    env = _env_with_graph()
    # Prefer an IOC seeded from the task definition; otherwise add a synthetic one.
    target = next(iter(env._threat_graph.iocs), None)
    if target is None:
        target = "1.2.3.4"
        _add_ioc(env, target)
    # NOTE(review): ioc_type is always "ip", even if the seeded IOC is of a
    # different type — confirm the handler does not reject the mismatch.
    env._handle_enrich_ioc(EnrichIOC(ioc_value=target, ioc_type="ip"))
    assert env._threat_graph.iocs[target].enriched is True
def test_enrich_ioc_error_when_not_in_graph():
    """Enriching an IOC value that is absent from the threat graph reports an error."""
    env = _env_with_graph()
    bogus = EnrichIOC(ioc_value="not.in.graph", ioc_type="ip")
    assert "error" in env._handle_enrich_ioc(bogus)
def test_scan_vulnerabilities_adds_vuln_nodes():
    """Scanning a host that has a vulnerability_chain entry adds a vuln node.

    The vulnerability count is captured just before the scan. This makes the
    assertion prove that the handler added a node, instead of passing on any
    vulnerabilities that may already exist after reset.
    """
    env = _env_with_graph()
    # Use a host that exists in the graph, or fall back to a synthetic one.
    hostname = next(iter(env._threat_graph.hosts), None)
    if hostname is None:
        hostname = "WS-FAKE"
        _add_host(env, hostname)
    # Seed a vulnerability_chain entry for this host.
    # NOTE(review): this mutates env._task_def in place; if that dict is shared
    # with a module-level task registry, the change leaks into later tests.
    env._task_def["vulnerability_chain"] = [{
        "hostname": hostname,
        "cve_id": "CVE-2024-9999",
        "cvss_score": 9.8,
        "exploitability": "active",
        "patch_available": True,
        "threat_id": "T1",
    }]
    before = len(env._threat_graph.vulnerabilities)
    env._handle_scan_vulnerabilities(ScanHostVulnerabilities(hostname=hostname))
    assert len(env._threat_graph.vulnerabilities) > before
def test_scan_vulnerabilities_marks_host_scanned():
    """A scan flags the host as scanned even when no vulnerabilities apply."""
    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), None)
    if target is None:
        target = "WS-FAKE"
        _add_host(env, target)
    env._task_def["vulnerability_chain"] = []  # nothing to discover on this host
    env._handle_scan_vulnerabilities(ScanHostVulnerabilities(hostname=target))
    assert env._threat_graph.hosts[target].scanned is True
def test_trigger_playbook_fails_without_prerequisites():
    """Triggering a playbook before any of its prerequisites are met is rejected."""
    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), "WS-001")
    playbook = TriggerPlaybook(playbook_name="ransomware_containment", target=target)
    outcome = env._handle_trigger_playbook(playbook)
    assert "error" in outcome
def test_trigger_playbook_adds_to_triggered_list():
    """Once its prerequisites hold, the playbook succeeds and is recorded as triggered."""
    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), None)
    if target is None:
        target = "WS-FAKE"
        _add_host(env, target)

    # Satisfy prerequisites: forensics_run_on_target (host listed in
    # scanned_hosts) and process_identified (a process on the host is in the graph).
    env._state.scanned_hosts.append(target)
    has_process = any(
        proc.hostname == target for proc in env._threat_graph.processes.values()
    )
    if not has_process:
        env._threat_graph.add_process(
            ProcessNode(process_id=f"{target}:1", hostname=target, process_name="evil.exe")
        )

    playbook = TriggerPlaybook(playbook_name="ransomware_containment", target=target)
    outcome = env._handle_trigger_playbook(playbook)
    assert "error" not in outcome
    assert "ransomware_containment" in env._state.triggered_playbooks
def test_query_host_returns_process_tree():
    """Querying a host places a ``process_tree`` entry in the observation extras."""
    from types import SimpleNamespace

    env = _env_with_graph()
    hostname = next(iter(env._threat_graph.hosts), None)
    if hostname is None:
        hostname = "WS-FAKE"
        _add_host(env, hostname)
    # Also add the host to host_index.
    # NOTE(review): the scraped source lost its indentation. This entry is
    # written unconditionally, matching the two-line fallback branch used in
    # the other tests — confirm against version control.
    env._host_index[hostname] = {
        "hostname": hostname, "subnet": "corporate",
        "status": "compromised", "running_processes": [], "criticality": 0.5,
    }
    env._last_obs_extras = {}
    # The action only needs a ``hostname`` attribute here; SimpleNamespace
    # expresses that directly, instead of building a throwaway class via type().
    env._handle_query_host(SimpleNamespace(hostname=hostname))
    assert "process_tree" in env._last_obs_extras
def test_isolate_single_host_sets_isolated():
    """Isolating one host marks it isolated without isolating its whole subnet."""
    from models import IsolateSegment

    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), None)
    if target is None:
        target = "WS-FAKE"
        _add_host(env, target)
    # NOTE(review): the scraped source lost its indentation. This host_index
    # entry is written unconditionally, matching the two-line fallback branch
    # used in the other tests — confirm against version control.
    env._host_index[target] = {
        "hostname": target, "subnet": "corporate",
        "status": "compromised", "running_processes": [], "criticality": 0.5,
    }
    env._handle_isolate_segment(IsolateSegment(target_host=target, reason="test"))
    assert env._host_index[target]["status"] == "isolated"

    # Only this host should be isolated, not every host in its subnet.
    subnet = env._host_index[target].get("subnet", "corporate")
    still_up = [h for h in env._network.get(subnet, []) if h["status"] != "isolated"]
    assert len(still_up) > 0