File size: 6,201 Bytes
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""Tests for Task 9 — 4 New Action Handlers + Enhanced Existing Handlers."""

import os
import sys

_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from server.play_environment import CyberSOCEnvironment
from server.threat_graph import HostNode, ProcessNode, IOCNode, AlertNode, Edge
from models import CorrelateAlerts, EnrichIOC, ScanHostVulnerabilities, TriggerPlaybook


def _env_with_graph():
    """Build a fresh CyberSOCEnvironment already reset to the "easy" task.

    The reset is what populates the environment's internal state (presumably
    including ``_threat_graph`` seeded from the task definition).
    """
    environment = CyberSOCEnvironment()
    environment.reset(task_id="easy")
    return environment


def _add_alerts(env, alert_ids):
    """Insert a high-severity AlertNode for every id not already in the graph.

    Ids that are already present in ``env._threat_graph.alerts`` are skipped,
    so calling this repeatedly is harmless.
    """
    graph = env._threat_graph
    for alert_id in alert_ids:
        if alert_id in graph.alerts:
            continue
        node = AlertNode(
            alert_id=alert_id,
            severity="high",
            priority_score=5.0,
            source_host="WS-001",
        )
        graph.add_alert(node)


def _add_ioc(env, value="1.2.3.4"):
    """Insert an IP-type IOCNode with confidence 0.9 unless *value* is already present."""
    graph = env._threat_graph
    if value in graph.iocs:
        return
    graph.add_ioc(IOCNode(ioc_value=value, ioc_type="ip", confidence=0.9))


def _add_host(env, hostname="WS-FAKE"):
    """Insert a compromised, medium-criticality corporate HostNode if missing."""
    graph = env._threat_graph
    if hostname in graph.hosts:
        return
    node = HostNode(
        hostname=hostname,
        subnet="corporate",
        business_criticality="medium",
        status="compromised",
    )
    graph.add_host(node)


def test_correlate_alerts_returns_correlation_results():
    """Correlating two alerts that exist in the graph yields a scored result."""
    env = _env_with_graph()
    ids = ["A1", "A2"]
    _add_alerts(env, ids)

    outcome = env._handle_correlate_alerts(CorrelateAlerts(alert_ids=ids))

    assert "correlation_results" in outcome
    correlation = outcome["correlation_results"]
    assert "correlation_score" in correlation


def test_correlate_alerts_error_on_single_id():
    """Correlation errors out when only one of the requested alerts resolves.

    Two ids are requested but only "A1" is registered in the graph, so the
    handler is left with fewer than two alerts to correlate.
    """
    env = _env_with_graph()
    _add_alerts(env, ["A1"])

    action = CorrelateAlerts(alert_ids=["A1", "MISSING"])
    outcome = env._handle_correlate_alerts(action)

    assert "error" in outcome


def test_enrich_ioc_updates_graph():
    """Enriching an IOC that lives in the graph sets its ``enriched`` flag."""
    env = _env_with_graph()
    iocs = env._threat_graph.iocs

    # Prefer an IOC seeded by the task definition; synthesize one otherwise.
    ioc_value = next(iter(iocs), None)
    if ioc_value is None:
        ioc_value = "1.2.3.4"
        _add_ioc(env, ioc_value)

    env._handle_enrich_ioc(EnrichIOC(ioc_value=ioc_value, ioc_type="ip"))

    enriched_node = env._threat_graph.iocs[ioc_value]
    assert enriched_node.enriched is True


def test_enrich_ioc_error_when_not_in_graph():
    """Enriching an IOC the graph has never seen returns an error payload."""
    env = _env_with_graph()
    action = EnrichIOC(ioc_value="not.in.graph", ioc_type="ip")
    outcome = env._handle_enrich_ioc(action)
    assert "error" in outcome


def test_scan_vulnerabilities_adds_vuln_nodes():
    """Scanning a host that has a vulnerability_chain entry adds vulnerability nodes.

    The vulnerability count is captured *before* the scan, so the assertion
    cannot pass merely because the task already seeded vulnerability nodes
    during ``reset``.
    """
    env = _env_with_graph()
    # Use a host that exists in the graph, falling back to a synthetic one.
    hostname = next(iter(env._threat_graph.hosts), None)
    if hostname is None:
        _add_host(env)
        hostname = "WS-FAKE"

    # Seed a single vulnerability_chain entry for this host.
    env._task_def["vulnerability_chain"] = [{
        "hostname": hostname,
        "cve_id": "CVE-2024-9999",
        "cvss_score": 9.8,
        "exploitability": "active",
        "patch_available": True,
        "threat_id": "T1",
    }]

    vulns_before = len(env._threat_graph.vulnerabilities)
    env._handle_scan_vulnerabilities(ScanHostVulnerabilities(hostname=hostname))

    # The scan itself must have contributed at least one new node.
    assert len(env._threat_graph.vulnerabilities) > vulns_before


def test_scan_vulnerabilities_marks_host_scanned():
    """A scan flags the host node as scanned even when no vulnerabilities match."""
    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), None)
    if target is None:
        target = "WS-FAKE"
        _add_host(env, target)

    # An empty chain isolates the "scanned" bookkeeping from vuln discovery.
    env._task_def["vulnerability_chain"] = []
    env._handle_scan_vulnerabilities(ScanHostVulnerabilities(hostname=target))

    host_node = env._threat_graph.hosts[target]
    assert host_node.scanned is True


def test_trigger_playbook_fails_without_prerequisites():
    """Triggering a playbook before its prerequisites are met returns an error."""
    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), "WS-001")
    action = TriggerPlaybook(playbook_name="ransomware_containment", target=target)
    outcome = env._handle_trigger_playbook(action)
    assert "error" in outcome


def test_trigger_playbook_adds_to_triggered_list():
    """Once prerequisites are satisfied, the playbook is recorded as triggered."""
    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), None)
    if target is None:
        target = "WS-FAKE"
        _add_host(env, target)

    # Prerequisite 1 (forensics_run_on_target): mark the host as scanned.
    env._state.scanned_hosts.append(target)

    # Prerequisite 2 (process_identified): ensure some process lives on the host.
    processes = env._threat_graph.processes.values()
    has_process = any(proc.hostname == target for proc in processes)
    if not has_process:
        env._threat_graph.add_process(ProcessNode(
            process_id=f"{target}:1", hostname=target, process_name="evil.exe"
        ))

    action = TriggerPlaybook(playbook_name="ransomware_containment", target=target)
    outcome = env._handle_trigger_playbook(action)

    assert "error" not in outcome
    assert "ransomware_containment" in env._state.triggered_playbooks


def test_query_host_returns_process_tree():
    """Querying a host publishes a ``process_tree`` entry in the observation extras."""
    from types import SimpleNamespace

    env = _env_with_graph()
    hostname = next(iter(env._threat_graph.hosts), None)
    if hostname is None:
        _add_host(env)
        hostname = "WS-FAKE"
        # The handler looks hosts up in _host_index too, so register it there.
        env._host_index[hostname] = {
            "hostname": hostname, "subnet": "corporate",
            "status": "compromised", "running_processes": [], "criticality": 0.5
        }

    # Start from empty extras so the assertion only sees what this call adds.
    env._last_obs_extras = {}
    # The handler is assumed to read only ``.hostname`` from the action, so a
    # SimpleNamespace suffices as a stand-in. That avoids synthesizing a
    # throwaway class via type().
    env._handle_query_host(SimpleNamespace(hostname=hostname))
    assert "process_tree" in env._last_obs_extras


def test_isolate_single_host_sets_isolated():
    """Isolating one host marks it "isolated" without quarantining its whole subnet.

    After the action, the host's own index entry must read "isolated", and
    its subnet must still contain at least one host that is not isolated.
    """
    env = _env_with_graph()
    target = next(iter(env._threat_graph.hosts), None)
    if target is None:
        target = "WS-FAKE"
        _add_host(env, target)
        env._host_index[target] = {
            "hostname": target, "subnet": "corporate",
            "status": "compromised", "running_processes": [], "criticality": 0.5
        }

    from models import IsolateSegment
    env._handle_isolate_segment(IsolateSegment(target_host=target, reason="test"))

    entry = env._host_index[target]
    assert entry["status"] == "isolated"

    # Only this host isolated, not all of its subnet.
    peers = env._network.get(entry.get("subnet", "corporate"), [])
    not_isolated = [peer for peer in peers if peer["status"] != "isolated"]
    assert len(not_isolated) > 0