| """Tests for Task 7 — Tool Router + Triage Solver.""" |
|
|
| import os |
| import sys |
|
|
# Put the repository root (the directory one level above the one holding this
# file) at the front of sys.path, as an absolute path, so the `server.*` imports
# below resolve regardless of the working directory. The membership check avoids
# inserting a duplicate entry when the root is already present.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
|
| from server.tool_router import ( |
| ToolRouter, |
| compute_triage_priority, |
| solve_triage_order, |
| ) |
| from server.threat_graph import ( |
| ThreatGraph, |
| AlertNode, |
| HostNode, |
| IOCNode, |
| ProcessNode, |
| ) |
|
|
|
|
def _alert(aid="A1", severity="high", source="WS-001"):
    """Build an ``AlertNode`` fixture.

    Args:
        aid: Value stored as ``alert_id``.
        severity: Severity label passed through unchanged.
        source: Hostname stored as ``source_host``.

    Returns:
        An ``AlertNode`` whose ``priority_score`` is always 1.0.
    """
    node_fields = {
        "alert_id": aid,
        "severity": severity,
        "priority_score": 1.0,
        "source_host": source,
    }
    return AlertNode(**node_fields)
|
|
|
|
def _host(name="WS-001", crit="medium", status="compromised"):
    """Build a ``HostNode`` fixture on the ``"corporate"`` subnet.

    Args:
        name: Value stored as ``hostname``.
        crit: Value stored as ``business_criticality``.
        status: Host status label (defaults to ``"compromised"``).

    Returns:
        A freshly constructed ``HostNode``.
    """
    return HostNode(
        hostname=name,
        business_criticality=crit,
        status=status,
        subnet="corporate",
    )
|
|
|
|
def _full_evidence_graph():
    """Return a ``ThreatGraph`` containing one host, one IP IOC and one process.

    The process is attached to the same host (``WS-001``). No alerts are added.
    The investigation-exit test uses this graph as its "sufficient evidence"
    input.
    """
    hostname = "WS-001"
    evidence = ThreatGraph()
    evidence.add_host(_host(hostname))
    evidence.add_ioc(IOCNode(ioc_value="1.1.1.1", ioc_type="ip", confidence=0.9))
    evidence.add_process(
        ProcessNode(
            process_id=hostname + ":1",
            hostname=hostname,
            process_name="x",
        )
    )
    return evidence
|
|
|
|
def test_triage_to_investigation_with_alerts():
    """Triage hands off to investigation when the graph holds an alert."""
    graph = ThreatGraph()
    graph.add_alert(_alert())
    router = ToolRouter()
    following = router.next_phase("triage", graph, 10)
    assert following == "investigation"
|
|
|
|
def test_triage_to_report_no_alerts():
    """With no alerts in the graph, triage goes straight to the report phase."""
    empty_graph = ThreatGraph()
    router = ToolRouter()
    following = router.next_phase("triage", empty_graph, 10)
    assert following == "report"
|
|
|
|
def test_investigation_loops_then_exits():
    """Investigation on an evidence-free graph must hand off to remediation.

    The hand-off must happen within ``MAX_INVESTIGATION_LOOPS + 1`` router calls.

    Each call feeds the phase returned by the previous one back into
    ``next_phase``, as a phase-driving loop would. The loop stops as soon as
    the router leaves investigation, so the assertion does not depend on what
    the router answers if it is asked about investigation again after it has
    already exited.
    """
    graph = ThreatGraph()
    router = ToolRouter()
    max_calls = router.MAX_INVESTIGATION_LOOPS + 1
    phase = "investigation"
    calls = 0
    while phase == "investigation" and calls < max_calls:
        phase = router.next_phase(phase, graph, 10)
        calls += 1
    assert phase == "remediation", (
        f"expected remediation, got {phase!r} after {calls} call(s) "
        f"(limit {max_calls})"
    )
|
|
|
|
def test_investigation_exits_on_sufficient_evidence():
    """A graph with host, IOC and process evidence moves investigation to remediation."""
    router = ToolRouter()
    evidence = _full_evidence_graph()
    following = router.next_phase("investigation", evidence, 10)
    assert following == "remediation"
|
|
|
|
def test_remediation_exits_when_contained():
    """Once the only host is isolated, remediation advances to the report phase."""
    router = ToolRouter()
    graph = ThreatGraph()
    isolated = _host("WS-001", status="isolated")
    graph.add_host(isolated)
    following = router.next_phase("remediation", graph, 10)
    assert following == "report"
|
|
|
|
def test_report_returns_done():
    """The report phase is terminal: the router answers ``"done"``."""
    router = ToolRouter()
    empty_graph = ThreatGraph()
    final_phase = router.next_phase("report", empty_graph, 10)
    assert final_phase == "done"
|
|
|
|
def test_honor_pushback_rejects_no_graph_refs():
    """Pushback that cites no graph entities is not honored."""
    router = ToolRouter()
    empty_graph = ThreatGraph()
    accepted, _unused = router.honor_pushback("investigation", [], empty_graph)
    assert accepted is False
|
|
|
|
def test_honor_pushback_accepts_valid_critical_alert():
    """Pushback citing an existing critical alert by id is honored."""
    graph = ThreatGraph()
    graph.add_alert(_alert("A1", severity="critical"))
    router = ToolRouter()
    cited_refs = ["A1"]
    accepted, _unused = router.honor_pushback("investigation", cited_refs, graph)
    assert accepted is True
|
|
|
|
def test_triage_priority_higher_for_critical():
    """A critical alert on a critical host outranks a low alert on a low host."""
    graph = ThreatGraph()
    critical_case = (
        _alert("A1", severity="critical"),
        _host("WS-001", crit="critical"),
    )
    low_case = (
        _alert("A2", severity="low", source="WS-002"),
        _host("WS-002", crit="low"),
    )
    high_score = compute_triage_priority(*critical_case, graph)
    low_score = compute_triage_priority(*low_case, graph)
    assert high_score > low_score
|
|
|
|
def test_solve_triage_order_descending():
    """Alerts come back ordered from most to least severe / critical."""
    # (alert id, host, host criticality, alert severity), listed in expected order.
    cases = [
        ("A1", "WS-001", "critical", "critical"),
        ("A2", "WS-002", "medium", "medium"),
        ("A3", "WS-003", "low", "low"),
    ]
    graph = ThreatGraph()
    # All hosts are registered before any alert, matching the fixture order.
    for _aid, hostname, criticality, _sev in cases:
        graph.add_host(_host(hostname, crit=criticality))
    for alert_id, hostname, _crit, severity in cases:
        graph.add_alert(_alert(alert_id, severity=severity, source=hostname))
    expected = [alert_id for alert_id, *_rest in cases]
    assert solve_triage_order(graph) == expected
|
|