# Demo/server/tool_router.py
# Ajayyy00
# Initial commit of CyberSOC upgraded RLVR environment
# 57e71f8
"""Deterministic Tool Router (phase machine) + Triage Solver."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .threat_graph import ThreatGraph, AlertNode, HostNode
class ToolRouter:
    """Deterministic phase machine that picks the next phase of an episode.

    Phases advance ``triage -> investigation -> remediation -> report`` and
    end with the sentinel ``"done"``.  Investigation may repeat, and
    remediation may hand control back to investigation; both loops are
    bounded by the ``MAX_*_LOOPS`` limits.  The loop counters live on the
    instance and persist until :meth:`reset` is called.
    """

    PHASE_ORDER = ["triage", "investigation", "remediation", "report"]
    MAX_INVESTIGATION_LOOPS = 4
    MAX_REMEDIATION_LOOPS = 3
    # Below this many remaining steps, investigation is left (or not
    # re-entered from remediation).
    MIN_STEPS_TO_INVESTIGATE = 4
    # Below this many remaining steps, remediation goes straight to report.
    MIN_STEPS_TO_REMEDIATE = 2

    def __init__(self) -> None:
        self._investigation_loop_count = 0
        self._remediation_loop_count = 0

    def next_phase(
        self,
        current_phase: str,
        graph: "ThreatGraph",
        steps_remaining: int,
    ) -> str:
        """Return the phase that should follow ``current_phase``.

        Args:
            current_phase: Name of the phase that just ran.
            graph: Threat graph inspected for alerts, evidence and host status.
            steps_remaining: Step budget left in the episode.

        Returns:
            One of the ``PHASE_ORDER`` names, or ``"done"`` after ``"report"``
            and for any unrecognised phase name.

        Side effects:
            Increments the investigation counter when investigation repeats,
            and the remediation counter when remediation loops back to
            investigation.
        """
        if current_phase == "triage":
            # With no alerts there is nothing to investigate.
            return "investigation" if graph.alerts else "report"

        if current_phase == "investigation":
            if (
                self._has_sufficient_evidence(graph)
                or steps_remaining < self.MIN_STEPS_TO_INVESTIGATE
                or self._investigation_loop_count >= self.MAX_INVESTIGATION_LOOPS
            ):
                return "remediation"
            self._investigation_loop_count += 1
            return "investigation"

        if current_phase == "remediation":
            if (
                self._all_threats_contained(graph)
                or steps_remaining < self.MIN_STEPS_TO_REMEDIATE
                or self._remediation_loop_count >= self.MAX_REMEDIATION_LOOPS
            ):
                return "report"
            # Reaching here means threats remain and loop budget is left, so
            # only the step budget decides whether to investigate again.
            if steps_remaining >= self.MIN_STEPS_TO_INVESTIGATE:
                self._remediation_loop_count += 1
                return "investigation"
            return "report"

        # "report" and any unknown phase both terminate the episode.
        return "done"

    def _has_sufficient_evidence(self, graph: "ThreatGraph") -> bool:
        """Return True when the graph has a non-healthy host, an IOC and a process."""
        has_unhealthy_host = any(
            h.status != "healthy" for h in graph.hosts.values()
        )
        return has_unhealthy_host and bool(graph.iocs) and bool(graph.processes)

    def _all_threats_contained(self, graph: "ThreatGraph") -> bool:
        """Return True when no host is still ``suspicious`` or ``compromised``.

        Hosts in any other status do not block containment.  An empty host
        map counts as contained.
        """
        return not any(
            h.status in ("suspicious", "compromised")
            for h in graph.hosts.values()
        )

    def reset(self) -> None:
        """Zero both loop counters so the router can be reused."""
        self._investigation_loop_count = 0
        self._remediation_loop_count = 0

    def honor_pushback(
        self,
        proposed_next_phase: str,
        justification_graph_refs: list[str],
        graph: "ThreatGraph",
    ) -> tuple[bool, str]:
        """Decide whether a proposed phase override is justified.

        Checks run in order and the first failure is reported:

        1. The proposed phase is in ``PHASE_ORDER`` or is ``"done"``.
        2. At least one graph reference is supplied.
        3. Every reference is the id of an alert, host, process, IOC or
           vulnerability in ``graph``.
        4. At least one reference is an alert with ``high`` or ``critical``
           severity.

        Returns:
            ``(True, "")`` when accepted, otherwise ``(False, reason)``.
        """
        if (
            proposed_next_phase not in self.PHASE_ORDER
            and proposed_next_phase != "done"
        ):
            return False, f"invalid phase '{proposed_next_phase}'"

        if not justification_graph_refs:
            return False, "no justification graph references provided"

        known_ids = set().union(
            graph.alerts.keys(),
            graph.hosts.keys(),
            graph.processes.keys(),
            graph.iocs.keys(),
            graph.vulnerabilities.keys(),
        )
        for ref in justification_graph_refs:
            if ref not in known_ids:
                return False, f"reference '{ref}' not present in graph"

        has_critical_alert = any(
            ref in graph.alerts
            and graph.alerts[ref].severity in ("high", "critical")
            for ref in justification_graph_refs
        )
        if not has_critical_alert:
            return False, "at least one referenced alert must be high/critical severity"
        return True, ""
# ===========================================================================
# Triage Solver
# ===========================================================================
# Weight per alert severity level; compute_triage_priority looks it up by
# ``alert.severity``.
SEVERITY_W = {"low": 1, "medium": 3, "high": 7, "critical": 15}
# Weight per host business-criticality level; compute_triage_priority looks
# it up by ``host.business_criticality``.
CRITICALITY_W = {"low": 1, "medium": 2, "high": 4, "critical": 8}
# Divisor of the blast-radius term: compute_triage_priority multiplies the
# score by (1 + blast_radius / REACHABILITY_SCALE).
REACHABILITY_SCALE = 10
def compute_triage_priority(
    alert: "AlertNode",
    host: "HostNode",
    graph: "ThreatGraph",
) -> float:
    """Score one alert on one host; a larger score means higher priority.

    The score is the severity weight times the host's criticality weight,
    scaled up by the number of graph edges whose source is this host.

    Raises:
        KeyError: if ``alert.severity`` or ``host.business_criticality`` is
            not a key of the corresponding weight table.
    """
    # Blast radius: how many edges originate from this host.
    outgoing_edges = [
        edge for edge in graph.edges if edge.source_id == host.hostname
    ]
    severity_weight = SEVERITY_W[alert.severity]
    criticality_weight = CRITICALITY_W[host.business_criticality]
    reach_multiplier = 1 + len(outgoing_edges) / REACHABILITY_SCALE
    return severity_weight * criticality_weight * reach_multiplier
def solve_triage_order(graph: "ThreatGraph") -> list[str]:
    """Return alert ids ordered from highest to lowest triage priority.

    Alerts whose ``source_host`` is not a key of ``graph.hosts`` are left
    out.  Alerts with equal scores keep the order in which ``graph.alerts``
    yields them, because the sort is stable.
    """
    ranked = []
    for candidate in graph.alerts.values():
        owner = graph.hosts.get(candidate.source_host)
        if owner is not None:
            priority = compute_triage_priority(candidate, owner, graph)
            ranked.append((priority, candidate.alert_id))
    # Sort on the score alone so ties never fall back to comparing ids.
    ranked = sorted(ranked, key=lambda entry: entry[0], reverse=True)
    return [entry[1] for entry in ranked]