Demo / server /action_validation.py
Ajayyy00
Replace SOAR playbooks with agentic micro-tools.
7211e63
"""3-gate action validation middleware: phase whitelist + schema + graph groundedness."""
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .threat_graph import ThreatGraph
PHASE_TOOL_WHITELIST = {
"triage": {"read_alerts", "read_topology", "correlate_alerts"},
"investigation": {"query_host", "run_forensics", "add_ioc",
"enrich_ioc", "scan_host_vulnerabilities"},
"remediation": {"block_ioc", "kill_process", "isolate_segment",
"terminate_pid", "create_firewall_rule", "quarantine_file",
"request_human_approval"},
"report": {"submit_containment_plan"},
}
PHASE_VIOLATION = "PHASE_VIOLATION"
INVALID_PARAMS = "INVALID_PARAMS"
UNGROUNDED_ACTION = "UNGROUNDED_ACTION"
_REQUIRED_ARGS = {
"block_ioc": ["ioc_value"],
"kill_process": ["hostname", "process_name"],
"isolate_segment": ["target"],
"correlate_alerts": ["alert_ids"],
"enrich_ioc": ["ioc_value", "ioc_type"],
"scan_host_vulnerabilities": ["hostname"],
"terminate_pid": ["hostname", "pid"],
"create_firewall_rule": ["hostname", "target_ip", "action"],
"quarantine_file": ["hostname", "file_path"],
}
class ActionValidationMiddleware:
def validate(
self,
phase: str,
tool_name: str,
arguments: dict,
graph: "ThreatGraph",
) -> Optional[dict]:
# Gate 1 — Phase whitelist
allowed = PHASE_TOOL_WHITELIST.get(phase, set())
if tool_name not in allowed:
return {
"error": PHASE_VIOLATION,
"message": (
f"Tool '{tool_name}' is not allowed in phase '{phase}'. "
f"Allowed tools: {sorted(allowed)}"
),
"retry": False,
}
# Gate 2 — Argument presence (basic schema check)
required = _REQUIRED_ARGS.get(tool_name, [])
for arg in required:
if arg not in arguments:
return {
"error": INVALID_PARAMS,
"message": f"Missing required argument '{arg}' for tool '{tool_name}'",
"retry": True,
}
if tool_name == "correlate_alerts":
ids = arguments.get("alert_ids", [])
if not isinstance(ids, (list, tuple)) or len(ids) < 2:
return {
"error": INVALID_PARAMS,
"message": "correlate_alerts requires at least 2 alert_ids",
"retry": True,
}
# Gate 3 — Graph groundedness
if tool_name == "block_ioc":
if arguments["ioc_value"] not in graph.iocs:
return {
"error": UNGROUNDED_ACTION,
"message": "IOC not in Threat Graph. Run investigation first.",
"retry": True,
}
elif tool_name == "kill_process":
key = f"{arguments['hostname']}:{arguments['process_name']}"
if key not in graph.processes:
return {
"error": UNGROUNDED_ACTION,
"message": "Process not in Threat Graph.",
"retry": True,
}
elif tool_name == "enrich_ioc":
if arguments["ioc_value"] not in graph.iocs:
return {
"error": UNGROUNDED_ACTION,
"message": "IOC not known. Discover it during investigation first.",
"retry": True,
}
return None