File size: 3,356 Bytes
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Tests for Task 7 — Tool Router + Triage Solver."""

import os
import sys

# Put the project root -- the parent of the directory that holds this file --
# at the front of ``sys.path`` so the ``server`` package imports below resolve
# however pytest is invoked. Nothing is added if the root is already present.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from server.tool_router import (
    ToolRouter,
    compute_triage_priority,
    solve_triage_order,
)
from server.threat_graph import (
    ThreatGraph,
    AlertNode,
    HostNode,
    IOCNode,
    ProcessNode,
)


def _alert(aid="A1", severity="high", source="WS-001"):
    """Build an ``AlertNode`` test fixture.

    Args:
        aid: Value for ``alert_id``.
        severity: Severity label, passed through unchanged.
        source: Hostname stored as the alert's ``source_host``.

    Returns:
        A new ``AlertNode`` whose ``priority_score`` is always 1.0.
    """
    fields = {
        "alert_id": aid,
        "severity": severity,
        "priority_score": 1.0,
        "source_host": source,
    }
    return AlertNode(**fields)


def _host(name="WS-001", crit="medium", status="compromised"):
    """Build a ``HostNode`` test fixture on the ``corporate`` subnet.

    Args:
        name: Value for ``hostname``.
        crit: Value for ``business_criticality``.
        status: Host status label (for example ``"compromised"`` or ``"isolated"``).

    Returns:
        A new ``HostNode``.
    """
    node = HostNode(
        hostname=name,
        subnet="corporate",
        business_criticality=crit,
        status=status,
    )
    return node


def _full_evidence_graph():
    """Return a ``ThreatGraph`` populated with one host, one IOC and one process.

    The host and the process both live on ``WS-001``. The investigation-phase
    test uses this graph as its "sufficient evidence" fixture.
    """
    graph = ThreatGraph()
    host = _host("WS-001")
    ioc = IOCNode(ioc_value="1.1.1.1", ioc_type="ip", confidence=0.9)
    proc = ProcessNode(process_id="WS-001:1", hostname="WS-001", process_name="x")
    graph.add_host(host)
    graph.add_ioc(ioc)
    graph.add_process(proc)
    return graph


def test_triage_to_investigation_with_alerts():
    """Triage advances to investigation once the graph holds an alert."""
    graph = ThreatGraph()
    graph.add_alert(_alert())
    router = ToolRouter()
    phase = router.next_phase("triage", graph, 10)
    assert phase == "investigation"


def test_triage_to_report_no_alerts():
    """With nothing to triage, the router goes straight to the report phase."""
    empty_graph = ThreatGraph()
    router = ToolRouter()
    phase = router.next_phase("triage", empty_graph, 10)
    assert phase == "report"


def test_investigation_loops_then_exits():
    """An evidence-free investigation repeats, then moves on to remediation.

    The router is called ``MAX_INVESTIGATION_LOOPS + 1`` times on a graph with
    no evidence. The phase must stay in investigation at least once (the
    "loops" part) and must end in remediation (the "exits" part).
    """
    g = ThreatGraph()  # evidence-free
    r = ToolRouter()
    phases = [
        r.next_phase("investigation", g, 10)
        for _ in range(r.MAX_INVESTIGATION_LOOPS + 1)
    ]
    # Checking only the last phase would also pass for a router that never
    # looped at all, so also require at least one repeat of investigation.
    assert "investigation" in phases
    assert phases[-1] == "remediation"


def test_investigation_exits_on_sufficient_evidence():
    """A graph with host, IOC and process evidence ends investigation in one step."""
    router = ToolRouter()
    evidence = _full_evidence_graph()
    phase = router.next_phase("investigation", evidence, 10)
    assert phase == "remediation"


def test_remediation_exits_when_contained():
    """Remediation hands off to reporting once the only host is isolated."""
    router = ToolRouter()
    graph = ThreatGraph()
    isolated_host = _host("WS-001", status="isolated")
    graph.add_host(isolated_host)
    phase = router.next_phase("remediation", graph, 10)
    assert phase == "report"


def test_report_returns_done():
    """After the report phase, the router signals ``"done"``."""
    router = ToolRouter()
    phase = router.next_phase("report", ThreatGraph(), 10)
    assert phase == "done"


def test_honor_pushback_rejects_no_graph_refs():
    """Pushback that cites no graph entities is refused."""
    router = ToolRouter()
    cited_refs = []
    accepted, _ = router.honor_pushback("investigation", cited_refs, ThreatGraph())
    assert accepted is False


def test_honor_pushback_accepts_valid_critical_alert():
    """Pushback that cites an existing critical alert by id is honored."""
    graph = ThreatGraph()
    critical = _alert("A1", severity="critical")
    graph.add_alert(critical)
    router = ToolRouter()
    accepted, _ = router.honor_pushback("investigation", ["A1"], graph)
    assert accepted is True


def test_triage_priority_higher_for_critical():
    """A critical alert on a critical host outranks a low alert on a low host."""
    graph = ThreatGraph()
    critical_alert = _alert("A1", severity="critical")
    low_alert = _alert("A2", severity="low", source="WS-002")
    critical_host = _host("WS-001", crit="critical")
    low_host = _host("WS-002", crit="low")
    pairs = ((critical_alert, critical_host), (low_alert, low_host))
    critical_score, low_score = (
        compute_triage_priority(alert, host, graph) for alert, host in pairs
    )
    assert critical_score > low_score


def test_solve_triage_order_descending():
    """Alerts come back ordered from highest to lowest triage priority."""
    graph = ThreatGraph()
    # (hostname, host criticality, alert id, alert severity)
    fixtures = [
        ("WS-001", "critical", "A1", "critical"),
        ("WS-002", "medium", "A2", "medium"),
        ("WS-003", "low", "A3", "low"),
    ]
    # Add every host first, then every alert.
    for hostname, criticality, _, _ in fixtures:
        graph.add_host(_host(hostname, crit=criticality))
    for hostname, _, alert_id, severity in fixtures:
        graph.add_alert(_alert(alert_id, severity=severity, source=hostname))
    ranked = solve_triage_order(graph)
    assert ranked == ["A1", "A2", "A3"]