"""3-gate action validation middleware: phase whitelist + schema + graph groundedness.""" from __future__ import annotations from typing import Optional, TYPE_CHECKING if TYPE_CHECKING: from .threat_graph import ThreatGraph PHASE_TOOL_WHITELIST = { "triage": {"read_alerts", "read_topology", "correlate_alerts"}, "investigation": {"query_host", "run_forensics", "add_ioc", "enrich_ioc", "scan_host_vulnerabilities"}, "remediation": {"block_ioc", "kill_process", "isolate_segment", "terminate_pid", "create_firewall_rule", "quarantine_file", "request_human_approval"}, "report": {"submit_containment_plan"}, } PHASE_VIOLATION = "PHASE_VIOLATION" INVALID_PARAMS = "INVALID_PARAMS" UNGROUNDED_ACTION = "UNGROUNDED_ACTION" _REQUIRED_ARGS = { "block_ioc": ["ioc_value"], "kill_process": ["hostname", "process_name"], "isolate_segment": ["target"], "correlate_alerts": ["alert_ids"], "enrich_ioc": ["ioc_value", "ioc_type"], "scan_host_vulnerabilities": ["hostname"], "terminate_pid": ["hostname", "pid"], "create_firewall_rule": ["hostname", "target_ip", "action"], "quarantine_file": ["hostname", "file_path"], } class ActionValidationMiddleware: def validate( self, phase: str, tool_name: str, arguments: dict, graph: "ThreatGraph", ) -> Optional[dict]: # Gate 1 — Phase whitelist allowed = PHASE_TOOL_WHITELIST.get(phase, set()) if tool_name not in allowed: return { "error": PHASE_VIOLATION, "message": ( f"Tool '{tool_name}' is not allowed in phase '{phase}'. " f"Allowed tools: {sorted(allowed)}" ), "retry": False, } # Gate 2 — Argument presence (basic schema check) required = _REQUIRED_ARGS.get(tool_name, []) for arg in required: if arg not in arguments: return { "error": INVALID_PARAMS, "message": f"Missing required argument '{arg}' for tool '{tool_name}'", "retry": True, } if tool_name == "correlate_alerts": ids = arguments.get("alert_ids", []) if not isinstance(ids, (list, tuple)) or len(ids) < 2: return { "error": INVALID_PARAMS, "message": "correlate_alerts requires at least 2 alert_ids", "retry": True, } # Gate 3 — Graph groundedness if tool_name == "block_ioc": if arguments["ioc_value"] not in graph.iocs: return { "error": UNGROUNDED_ACTION, "message": "IOC not in Threat Graph. Run investigation first.", "retry": True, } elif tool_name == "kill_process": key = f"{arguments['hostname']}:{arguments['process_name']}" if key not in graph.processes: return { "error": UNGROUNDED_ACTION, "message": "Process not in Threat Graph.", "retry": True, } elif tool_name == "enrich_ioc": if arguments["ioc_value"] not in graph.iocs: return { "error": UNGROUNDED_ACTION, "message": "IOC not known. Discover it during investigation first.", "retry": True, } return None