| """Deterministic Tool Router (phase machine) + Triage Solver.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import TYPE_CHECKING |
|
|
| if TYPE_CHECKING: |
| from .threat_graph import ThreatGraph, AlertNode, HostNode |
|
|
|
|
| class ToolRouter: |
| PHASE_ORDER = ["triage", "investigation", "remediation", "report"] |
| MAX_INVESTIGATION_LOOPS = 4 |
| MAX_REMEDIATION_LOOPS = 3 |
|
|
| def __init__(self): |
| self._investigation_loop_count = 0 |
| self._remediation_loop_count = 0 |
|
|
| def next_phase( |
| self, |
| current_phase: str, |
| graph: "ThreatGraph", |
| steps_remaining: int, |
| ) -> str: |
| if current_phase == "triage": |
| if len(graph.alerts) > 0: |
| return "investigation" |
| return "report" |
|
|
| if current_phase == "investigation": |
| if ( |
| self._has_sufficient_evidence(graph) |
| or steps_remaining < 4 |
| or self._investigation_loop_count >= self.MAX_INVESTIGATION_LOOPS |
| ): |
| return "remediation" |
| self._investigation_loop_count += 1 |
| return "investigation" |
|
|
| if current_phase == "remediation": |
| if ( |
| self._all_threats_contained(graph) |
| or steps_remaining < 2 |
| or self._remediation_loop_count >= self.MAX_REMEDIATION_LOOPS |
| ): |
| return "report" |
| if ( |
| self._remediation_loop_count < self.MAX_REMEDIATION_LOOPS |
| and not self._all_threats_contained(graph) |
| and steps_remaining >= 4 |
| ): |
| self._remediation_loop_count += 1 |
| return "investigation" |
| return "report" |
|
|
| if current_phase == "report": |
| return "done" |
|
|
| return "done" |
|
|
| def _has_sufficient_evidence(self, graph: "ThreatGraph") -> bool: |
| has_unhealthy_host = any(h.status != "healthy" for h in graph.hosts.values()) |
| has_ioc = len(graph.iocs) > 0 |
| has_process = len(graph.processes) > 0 |
| return has_unhealthy_host and has_ioc and has_process |
|
|
| def _all_threats_contained(self, graph: "ThreatGraph") -> bool: |
| suspicious_or_compromised = [ |
| h for h in graph.hosts.values() |
| if h.status in ("suspicious", "compromised") |
| ] |
| if not suspicious_or_compromised: |
| return True |
| return all( |
| h.status in ("isolated", "contained") for h in graph.hosts.values() |
| ) |
|
|
| def reset(self): |
| self._investigation_loop_count = 0 |
| self._remediation_loop_count = 0 |
|
|
| def honor_pushback( |
| self, |
| proposed_next_phase: str, |
| justification_graph_refs: list[str], |
| graph: "ThreatGraph", |
| ) -> tuple[bool, str]: |
| if proposed_next_phase not in self.PHASE_ORDER and proposed_next_phase != "done": |
| return False, f"invalid phase '{proposed_next_phase}'" |
| if not justification_graph_refs: |
| return False, "no justification graph references provided" |
|
|
| all_node_ids = ( |
| set(graph.alerts.keys()) |
| | set(graph.hosts.keys()) |
| | set(graph.processes.keys()) |
| | set(graph.iocs.keys()) |
| | set(graph.vulnerabilities.keys()) |
| ) |
| for ref in justification_graph_refs: |
| if ref not in all_node_ids: |
| return False, f"reference '{ref}' not present in graph" |
|
|
| has_critical_alert = any( |
| ref in graph.alerts and graph.alerts[ref].severity in ("high", "critical") |
| for ref in justification_graph_refs |
| ) |
| if not has_critical_alert: |
| return False, "at least one referenced alert must be high/critical severity" |
|
|
| return True, "" |
|
|
|
|
| |
| |
| |
|
|
| SEVERITY_W = {"low": 1, "medium": 3, "high": 7, "critical": 15} |
| CRITICALITY_W = {"low": 1, "medium": 2, "high": 4, "critical": 8} |
| REACHABILITY_SCALE = 10 |
|
|
|
|
| def compute_triage_priority( |
| alert: "AlertNode", |
| host: "HostNode", |
| graph: "ThreatGraph", |
| ) -> float: |
| blast_radius = sum(1 for e in graph.edges if e.source_id == host.hostname) |
| return ( |
| SEVERITY_W[alert.severity] |
| * CRITICALITY_W[host.business_criticality] |
| * (1 + blast_radius / REACHABILITY_SCALE) |
| ) |
|
|
|
|
| def solve_triage_order(graph: "ThreatGraph") -> list[str]: |
| scored: list[tuple[float, str]] = [] |
| for alert in graph.alerts.values(): |
| host = graph.hosts.get(alert.source_host) |
| if host is None: |
| continue |
| score = compute_triage_priority(alert, host, graph) |
| scored.append((score, alert.alert_id)) |
| scored.sort(key=lambda t: t[0], reverse=True) |
| return [aid for _, aid in scored] |
|
|