File size: 4,901 Bytes
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Deterministic Tool Router (phase machine) + Triage Solver."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .threat_graph import ThreatGraph, AlertNode, HostNode


class ToolRouter:
    """Deterministic phase machine that picks the next phase of the agent loop.

    Phases advance through ``PHASE_ORDER`` (triage -> investigation ->
    remediation -> report) and then to the terminal value ``"done"``.
    Investigation may repeat itself, and remediation may send the loop back to
    investigation. Both are bounded by per-router loop counters and by the
    remaining step budget passed to :meth:`next_phase`.
    """

    PHASE_ORDER = ["triage", "investigation", "remediation", "report"]
    # Maximum number of times "investigation" may be followed by itself.
    MAX_INVESTIGATION_LOOPS = 4
    # Maximum number of remediation -> investigation round-trips.
    MAX_REMEDIATION_LOOPS = 3

    def __init__(self) -> None:
        self._investigation_loop_count = 0
        self._remediation_loop_count = 0

    def next_phase(
        self,
        current_phase: str,
        graph: "ThreatGraph",
        steps_remaining: int,
    ) -> str:
        """Return the phase that should follow ``current_phase``.

        Args:
            current_phase: The phase the agent is currently in.
            graph: Threat graph consulted for alerts, evidence and host status.
            steps_remaining: Remaining step budget.

        Returns:
            One of ``PHASE_ORDER`` or ``"done"``. ``"report"`` and any
            unrecognised phase both map to ``"done"``.

        Side effects:
            Increments the investigation loop counter when investigation
            repeats. Increments the remediation loop counter when remediation
            loops back to investigation.
        """
        if current_phase == "triage":
            # Nothing to investigate without alerts; go straight to the report.
            return "investigation" if len(graph.alerts) > 0 else "report"

        if current_phase == "investigation":
            if (
                self._has_sufficient_evidence(graph)
                or steps_remaining < 4
                or self._investigation_loop_count >= self.MAX_INVESTIGATION_LOOPS
            ):
                return "remediation"
            self._investigation_loop_count += 1
            return "investigation"

        if current_phase == "remediation":
            if (
                self._all_threats_contained(graph)
                or steps_remaining < 2
                or self._remediation_loop_count >= self.MAX_REMEDIATION_LOOPS
            ):
                return "report"
            # Past the guard, threats are uncontained and the loop budget is
            # not exhausted; only the step budget still matters. Looping back
            # requires the same 4-step minimum investigation uses to continue,
            # so with 2-3 steps left we finish with the report instead.
            if steps_remaining >= 4:
                self._remediation_loop_count += 1
                return "investigation"
            return "report"

        # "report" and any unknown phase are terminal.
        return "done"

    def _has_sufficient_evidence(self, graph: "ThreatGraph") -> bool:
        """Return True when the graph holds enough evidence to start remediating.

        That means all of the following hold: at least one host whose status is
        not ``"healthy"``, at least one IOC, and at least one process.
        """
        has_unhealthy_host = any(h.status != "healthy" for h in graph.hosts.values())
        has_ioc = len(graph.iocs) > 0
        has_process = len(graph.processes) > 0
        return has_unhealthy_host and has_ioc and has_process

    def _all_threats_contained(self, graph: "ThreatGraph") -> bool:
        """Return True when no host is still ``"suspicious"`` or ``"compromised"``.

        Hosts in any other status (e.g. ``"healthy"``, ``"isolated"``,
        ``"contained"``) do not block containment. An empty host map counts
        as contained.
        """
        # A host that is still suspicious/compromised is by definition not
        # isolated/contained, so no separate check of the other hosts is needed.
        return not any(
            h.status in ("suspicious", "compromised") for h in graph.hosts.values()
        )

    def reset(self) -> None:
        """Zero both loop counters so the router can drive a fresh episode."""
        self._investigation_loop_count = 0
        self._remediation_loop_count = 0

    def honor_pushback(
        self,
        proposed_next_phase: str,
        justification_graph_refs: list[str],
        graph: "ThreatGraph",
    ) -> tuple[bool, str]:
        """Validate a proposed phase override and its graph-based justification.

        The override is accepted only when all of these hold:

        * ``proposed_next_phase`` is in ``PHASE_ORDER`` or is ``"done"``.
        * ``justification_graph_refs`` is non-empty.
        * Every reference is a key of ``graph.alerts``, ``graph.hosts``,
          ``graph.processes``, ``graph.iocs`` or ``graph.vulnerabilities``.
        * At least one reference names an alert of ``"high"`` or
          ``"critical"`` severity.

        This method does not change the router's state. It only reports a
        verdict.

        Returns:
            ``(True, "")`` when accepted, otherwise ``(False, reason)``.
        """
        if proposed_next_phase not in self.PHASE_ORDER and proposed_next_phase != "done":
            return False, f"invalid phase '{proposed_next_phase}'"
        if not justification_graph_refs:
            return False, "no justification graph references provided"

        all_node_ids = (
            set(graph.alerts.keys())
            | set(graph.hosts.keys())
            | set(graph.processes.keys())
            | set(graph.iocs.keys())
            | set(graph.vulnerabilities.keys())
        )
        for ref in justification_graph_refs:
            if ref not in all_node_ids:
                return False, f"reference '{ref}' not present in graph"

        has_critical_alert = any(
            ref in graph.alerts and graph.alerts[ref].severity in ("high", "critical")
            for ref in justification_graph_refs
        )
        if not has_critical_alert:
            return False, "at least one referenced alert must be high/critical severity"

        return True, ""


# ===========================================================================
# Triage Solver
# ===========================================================================

# Weight applied to an alert's severity when scoring triage priority.
SEVERITY_W = dict(low=1, medium=3, high=7, critical=15)
# Weight applied to a host's business criticality when scoring triage priority.
CRITICALITY_W = dict(low=1, medium=2, high=4, critical=8)
# Divisor for a host's outgoing-edge count in the reachability factor
# ``1 + blast_radius / REACHABILITY_SCALE``.
REACHABILITY_SCALE = 10


def compute_triage_priority(
    alert: "AlertNode",
    host: "HostNode",
    graph: "ThreatGraph",
) -> float:
    """Score how urgently ``alert`` raised on ``host`` should be triaged.

    The score is the product, in this order, of:

    * ``SEVERITY_W[alert.severity]``
    * ``CRITICALITY_W[host.business_criticality]``
    * ``1 + blast_radius / REACHABILITY_SCALE``, where ``blast_radius`` is the
      number of edges in ``graph.edges`` whose ``source_id`` equals
      ``host.hostname``.

    Raises:
        KeyError: if the alert's severity or the host's business criticality
            is not a key of the corresponding weight table.
    """
    outgoing = [edge for edge in graph.edges if edge.source_id == host.hostname]
    blast_radius = len(outgoing)
    severity_weight = SEVERITY_W[alert.severity]
    criticality_weight = CRITICALITY_W[host.business_criticality]
    reach_factor = 1 + blast_radius / REACHABILITY_SCALE
    return severity_weight * criticality_weight * reach_factor


def solve_triage_order(graph: "ThreatGraph") -> list[str]:
    """Return alert IDs ordered from highest to lowest triage priority.

    Alerts whose ``source_host`` has no entry in ``graph.hosts`` are skipped.
    The sort is stable, so alerts with equal scores keep the iteration order
    of ``graph.alerts``.
    """
    ranked: list[tuple[float, str]] = []
    for candidate in graph.alerts.values():
        origin = graph.hosts.get(candidate.source_host)
        if origin is not None:
            priority = compute_triage_priority(candidate, origin, graph)
            ranked.append((priority, candidate.alert_id))
    ordered = sorted(ranked, key=lambda pair: pair[0], reverse=True)
    return [alert_id for _priority, alert_id in ordered]