# File size: 4,901 Bytes (revision 57e71f8)
"""Deterministic Tool Router (phase machine) + Triage Solver."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .threat_graph import ThreatGraph, AlertNode, HostNode
class ToolRouter:
    """Deterministic phase machine that picks the next workflow phase.

    Phases advance ``triage -> investigation -> remediation -> report`` and
    finally ``"done"``. Investigation may repeat, and remediation may fall
    back to investigation; both kinds of repetition are counted on the
    instance and capped by the class constants below. The counters are only
    cleared by :meth:`reset`.
    """

    PHASE_ORDER = ["triage", "investigation", "remediation", "report"]
    # Cap on "stay in investigation" decisions between resets.
    MAX_INVESTIGATION_LOOPS = 4
    # Cap on remediation -> investigation fallbacks between resets.
    MAX_REMEDIATION_LOOPS = 3

    def __init__(self):
        self._investigation_loop_count = 0
        self._remediation_loop_count = 0

    def next_phase(
        self,
        current_phase: str,
        graph: "ThreatGraph",
        steps_remaining: int,
    ) -> str:
        """Return the phase that should follow ``current_phase``.

        Args:
            current_phase: Name of the phase that just ran.
            graph: Threat graph whose ``alerts``, ``hosts``, ``iocs`` and
                ``processes`` mappings drive the decision.
            steps_remaining: Remaining step budget; small budgets force the
                machine forward.

        Returns:
            One of the names in ``PHASE_ORDER`` or ``"done"``. Any phase
            name that is not recognised yields ``"done"``.

        Side effects:
            Increments the investigation counter when investigation repeats
            and the remediation counter when remediation falls back to
            investigation.
        """
        if current_phase == "triage":
            # Nothing to investigate without alerts: go straight to the report.
            return "investigation" if graph.alerts else "report"

        if current_phase == "investigation":
            if (
                self._has_sufficient_evidence(graph)
                or steps_remaining < 4
                or self._investigation_loop_count >= self.MAX_INVESTIGATION_LOOPS
            ):
                return "remediation"
            self._investigation_loop_count += 1
            return "investigation"

        if current_phase == "remediation":
            if (
                self._all_threats_contained(graph)
                or steps_remaining < 2
                or self._remediation_loop_count >= self.MAX_REMEDIATION_LOOPS
            ):
                return "report"
            # Reaching this point already implies threats remain and the
            # fallback cap has not been hit, so only the step budget decides
            # whether another investigation round is affordable.
            if steps_remaining >= 4:
                self._remediation_loop_count += 1
                return "investigation"
            return "report"

        # "report" and every unrecognised phase both terminate the machine.
        return "done"

    def _has_sufficient_evidence(self, graph: "ThreatGraph") -> bool:
        """Return True when the graph holds a non-healthy host, an IOC and a process."""
        return (
            any(host.status != "healthy" for host in graph.hosts.values())
            and len(graph.iocs) > 0
            and len(graph.processes) > 0
        )

    def _all_threats_contained(self, graph: "ThreatGraph") -> bool:
        """Return True when no host is ``"suspicious"`` or ``"compromised"``.

        An empty host map counts as contained.
        """
        return not any(
            host.status in ("suspicious", "compromised")
            for host in graph.hosts.values()
        )

    def reset(self):
        """Clear both loop counters."""
        self._investigation_loop_count = 0
        self._remediation_loop_count = 0

    def honor_pushback(
        self,
        proposed_next_phase: str,
        justification_graph_refs: list[str],
        graph: "ThreatGraph",
    ) -> tuple[bool, str]:
        """Decide whether a proposed phase override is adequately justified.

        The proposal is accepted only if the phase is in ``PHASE_ORDER`` or is
        ``"done"``, at least one reference is supplied, every reference is a
        key of one of the graph's node mappings, and at least one reference
        is an alert whose severity is ``"high"`` or ``"critical"``.

        Args:
            proposed_next_phase: Phase the caller wants to move to.
            justification_graph_refs: Node ids cited as justification.
            graph: Threat graph the references are resolved against.

        Returns:
            ``(True, "")`` when accepted, otherwise ``(False, reason)``.
            Router state is not modified.
        """
        if proposed_next_phase not in self.PHASE_ORDER and proposed_next_phase != "done":
            return False, f"invalid phase '{proposed_next_phase}'"
        if not justification_graph_refs:
            return False, "no justification graph references provided"

        # Probe each mapping directly instead of materialising a union set of
        # every node id on each call.
        node_maps = (
            graph.alerts,
            graph.hosts,
            graph.processes,
            graph.iocs,
            graph.vulnerabilities,
        )
        for ref in justification_graph_refs:
            if not any(ref in node_map for node_map in node_maps):
                return False, f"reference '{ref}' not present in graph"

        cites_severe_alert = any(
            ref in graph.alerts and graph.alerts[ref].severity in ("high", "critical")
            for ref in justification_graph_refs
        )
        if not cites_severe_alert:
            return False, "at least one referenced alert must be high/critical severity"
        return True, ""
# ===========================================================================
# Triage Solver
# ===========================================================================
# Weight applied per alert severity label (looked up by ``alert.severity``
# in compute_triage_priority).
SEVERITY_W = {
    "low": 1,
    "medium": 3,
    "high": 7,
    "critical": 15,
}
# Weight applied per host business-criticality label (looked up by
# ``host.business_criticality`` in compute_triage_priority).
CRITICALITY_W = {
    "low": 1,
    "medium": 2,
    "high": 4,
    "critical": 8,
}
# Divisor for the blast-radius term: each outgoing edge of the host adds
# 1 / REACHABILITY_SCALE to the priority multiplier.
REACHABILITY_SCALE = 10
def compute_triage_priority(
    alert: "AlertNode",
    host: "HostNode",
    graph: "ThreatGraph",
) -> float:
    """Score one alert on its host; a larger score means higher priority.

    The score is the severity weight times the host criticality weight,
    scaled up by the number of graph edges whose ``source_id`` equals the
    host's ``hostname``.

    Raises:
        KeyError: If ``alert.severity`` or ``host.business_criticality`` is
            not a key of the corresponding weight table.
    """
    outgoing_edges = [
        edge for edge in graph.edges if edge.source_id == host.hostname
    ]
    base_weight = SEVERITY_W[alert.severity] * CRITICALITY_W[host.business_criticality]
    reach_multiplier = 1 + len(outgoing_edges) / REACHABILITY_SCALE
    return base_weight * reach_multiplier
def solve_triage_order(graph: "ThreatGraph") -> list[str]:
    """Return alert ids ordered from highest to lowest triage priority.

    Alerts whose ``source_host`` is not a key of ``graph.hosts`` are left
    out. Alerts with equal scores keep the order in which ``graph.alerts``
    yields them, because the sort is stable.
    """
    alerts_with_hosts = (
        (alert, graph.hosts.get(alert.source_host))
        for alert in graph.alerts.values()
    )
    ranked = [
        (compute_triage_priority(alert, host, graph), alert.alert_id)
        for alert, host in alerts_with_hosts
        if host is not None
    ]
    ordered = sorted(ranked, key=lambda pair: pair[0], reverse=True)
    return [alert_id for _score, alert_id in ordered]
|