"""3-gate action validation middleware: phase whitelist + schema + graph groundedness."""
from __future__ import annotations

from collections.abc import Mapping
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from .threat_graph import ThreatGraph
# Gate 1 table: the tools each workflow phase may invoke. A call whose tool
# is not in the set for the current phase is rejected with PHASE_VIOLATION;
# an unknown phase has no allowed tools at all.
PHASE_TOOL_WHITELIST = {
    "triage": {
        "read_alerts",
        "read_topology",
        "correlate_alerts",
    },
    "investigation": {
        "query_host",
        "run_forensics",
        "add_ioc",
        "enrich_ioc",
        "scan_host_vulnerabilities",
    },
    "remediation": {
        "block_ioc",
        "kill_process",
        "isolate_segment",
        "terminate_pid",
        "create_firewall_rule",
        "quarantine_file",
        "request_human_approval",
    },
    "report": {
        "submit_containment_plan",
    },
}

# Error codes written to the "error" field of a rejected tool call.
PHASE_VIOLATION = "PHASE_VIOLATION"      # tool not allowed in this phase
INVALID_PARAMS = "INVALID_PARAMS"        # arguments fail the schema check
UNGROUNDED_ACTION = "UNGROUNDED_ACTION"  # target not present in the graph

# Gate 2 table: argument names that must be present for a given tool.
# Tools without an entry have no required arguments checked.
_REQUIRED_ARGS = {
    "block_ioc": ["ioc_value"],
    "kill_process": ["hostname", "process_name"],
    "isolate_segment": ["target"],
    "correlate_alerts": ["alert_ids"],
    "enrich_ioc": ["ioc_value", "ioc_type"],
    "scan_host_vulnerabilities": ["hostname"],
    "terminate_pid": ["hostname", "pid"],
    "create_firewall_rule": ["hostname", "target_ip", "action"],
    "quarantine_file": ["hostname", "file_path"],
}
class ActionValidationMiddleware:
    """Check an agent tool call against three gates before it runs.

    The gates are evaluated in order and the first failure is returned:

    1. Phase whitelist -- ``tool_name`` must appear in
       ``PHASE_TOOL_WHITELIST[phase]``. Rejections are not retryable.
    2. Schema -- ``arguments`` must be a mapping holding every name listed
       for the tool in ``_REQUIRED_ARGS``; ``correlate_alerts`` also needs
       a list/tuple of at least two ``alert_ids``.
    3. Groundedness -- ``block_ioc`` / ``enrich_ioc`` require ``ioc_value``
       to be in ``graph.iocs``; ``kill_process`` requires the key
       ``"<hostname>:<process_name>"`` to be in ``graph.processes``.
    """

    @staticmethod
    def _error(code: str, message: str, retry: bool) -> dict:
        """Build the rejection payload returned by a failing gate."""
        return {"error": code, "message": message, "retry": retry}

    def validate(
        self,
        phase: str,
        tool_name: str,
        arguments: Optional[dict],
        graph: "ThreatGraph",
    ) -> Optional[dict]:
        """Run all three gates against one tool call.

        Args:
            phase: Current workflow phase name.
            tool_name: Name of the tool the agent wants to invoke.
            arguments: Tool arguments. ``None`` is treated as an empty
                mapping; any other non-mapping value is rejected with
                ``INVALID_PARAMS``.
            graph: Threat graph consulted by the groundedness gate
                (its ``iocs`` and ``processes`` collections).

        Returns:
            ``None`` if the call passes every gate; otherwise a dict with
            keys ``"error"`` (``PHASE_VIOLATION``, ``INVALID_PARAMS`` or
            ``UNGROUNDED_ACTION``), ``"message"`` and ``"retry"``.
        """
        # Gate 1 — Phase whitelist
        allowed = PHASE_TOOL_WHITELIST.get(phase, set())
        if tool_name not in allowed:
            return self._error(
                PHASE_VIOLATION,
                f"Tool '{tool_name}' is not allowed in phase '{phase}'. "
                f"Allowed tools: {sorted(allowed)}",
                retry=False,
            )

        # Gate 2 — Argument shape and presence (basic schema check)
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, Mapping):
            # Guard before any `in` test: on a str payload `arg in arguments`
            # is a substring match and would wrongly pass this gate, then
            # crash on `arguments[...]` in Gate 3.
            return self._error(
                INVALID_PARAMS,
                f"Arguments for tool '{tool_name}' must be a mapping, "
                f"got {type(arguments).__name__}",
                retry=True,
            )
        for arg in _REQUIRED_ARGS.get(tool_name, []):
            if arg not in arguments:
                return self._error(
                    INVALID_PARAMS,
                    f"Missing required argument '{arg}' for tool '{tool_name}'",
                    retry=True,
                )
        if tool_name == "correlate_alerts":
            ids = arguments.get("alert_ids", [])
            if not isinstance(ids, (list, tuple)) or len(ids) < 2:
                return self._error(
                    INVALID_PARAMS,
                    "correlate_alerts requires at least 2 alert_ids",
                    retry=True,
                )

        # Gate 3 — Graph groundedness: the target must already be recorded
        # in the threat graph.
        if tool_name == "block_ioc":
            if arguments["ioc_value"] not in graph.iocs:
                return self._error(
                    UNGROUNDED_ACTION,
                    "IOC not in Threat Graph. Run investigation first.",
                    retry=True,
                )
        elif tool_name == "kill_process":
            # Looked up under the composite key "<hostname>:<process_name>".
            key = f"{arguments['hostname']}:{arguments['process_name']}"
            if key not in graph.processes:
                return self._error(
                    UNGROUNDED_ACTION,
                    "Process not in Threat Graph.",
                    retry=True,
                )
        elif tool_name == "enrich_ioc":
            if arguments["ioc_value"] not in graph.iocs:
                return self._error(
                    UNGROUNDED_ACTION,
                    "IOC not known. Discover it during investigation first.",
                    retry=True,
                )
        return None