| """3-gate action validation middleware: phase whitelist + schema + graph groundedness.""" |
|
|
| from __future__ import annotations |
|
|
from collections.abc import Mapping
from typing import Optional, TYPE_CHECKING
|
|
| if TYPE_CHECKING: |
| from .threat_graph import ThreatGraph |
|
|
|
|
# Gate 1 lookup table: phase name -> set of tool names that may be invoked
# while that phase is active. A phase missing from this table allows no tools
# (the validator falls back to an empty set).
PHASE_TOOL_WHITELIST = {
    "triage": {
        "read_alerts",
        "read_topology",
        "correlate_alerts",
    },
    "investigation": {
        "query_host",
        "run_forensics",
        "add_ioc",
        "enrich_ioc",
        "scan_host_vulnerabilities",
    },
    "remediation": {
        "block_ioc",
        "kill_process",
        "isolate_segment",
        "terminate_pid",
        "create_firewall_rule",
        "quarantine_file",
        "request_human_approval",
    },
    "report": {
        "submit_containment_plan",
    },
}
|
|
# Error codes reported under the "error" key of a rejection dict.

# The tool is not whitelisted for the current phase.
PHASE_VIOLATION = "PHASE_VIOLATION"
# A required argument is missing or an argument has an unusable shape.
INVALID_PARAMS = "INVALID_PARAMS"
# The action targets an entity that is not recorded in the threat graph.
UNGROUNDED_ACTION = "UNGROUNDED_ACTION"
|
|
|
|
| _REQUIRED_ARGS = { |
| "block_ioc": ["ioc_value"], |
| "kill_process": ["hostname", "process_name"], |
| "isolate_segment": ["target"], |
| "correlate_alerts": ["alert_ids"], |
| "enrich_ioc": ["ioc_value", "ioc_type"], |
| "scan_host_vulnerabilities": ["hostname"], |
| "terminate_pid": ["hostname", "pid"], |
| "create_firewall_rule": ["hostname", "target_ip", "action"], |
| "quarantine_file": ["hostname", "file_path"], |
| } |
|
|
|
|
class ActionValidationMiddleware:
    """Vet a proposed tool call before it is executed.

    A call must clear three gates, checked in order; the first failure wins:

    1. Phase whitelist -- the tool must be listed for the current phase in
       ``PHASE_TOOL_WHITELIST``.
    2. Schema -- every argument named in ``_REQUIRED_ARGS`` for the tool must
       be supplied, and ``correlate_alerts`` needs at least two alert ids.
    3. Groundedness -- ``block_ioc`` / ``enrich_ioc`` targets must be members
       of ``graph.iocs`` and ``kill_process`` targets members of
       ``graph.processes``.
    """

    def validate(
        self,
        phase: str,
        tool_name: str,
        arguments: dict,
        graph: "ThreatGraph",
    ) -> Optional[dict]:
        """Run the three gates against one proposed tool call.

        Args:
            phase: Name of the current phase (a key of
                ``PHASE_TOOL_WHITELIST``); unknown phases allow no tools.
            tool_name: Name of the tool being invoked.
            arguments: Mapping of argument name to value. ``None`` is treated
                as "no arguments supplied".
            graph: Object exposing ``iocs`` and ``processes`` collections
                that support membership tests.

        Returns:
            ``None`` when the call is accepted. Otherwise a dict with keys
            ``"error"`` (one of the module's error codes), ``"message"``
            (human-readable explanation) and ``"retry"`` (``False`` only for
            phase violations).
        """
        rejection = self._check_phase(phase, tool_name)
        if rejection is not None:
            return rejection

        # Normalise the payload so the later gates can rely on mapping
        # semantics instead of crashing on None or a non-mapping value.
        if arguments is None:
            arguments = {}
        elif not isinstance(arguments, Mapping):
            return self._reject(
                INVALID_PARAMS,
                f"Arguments for tool '{tool_name}' must be a mapping, "
                f"got {type(arguments).__name__}",
                True,
            )

        rejection = self._check_schema(tool_name, arguments)
        if rejection is not None:
            return rejection

        return self._check_grounding(tool_name, arguments, graph)

    @staticmethod
    def _reject(code: str, message: str, retry: bool) -> dict:
        """Build the rejection dict returned to the caller."""
        return {"error": code, "message": message, "retry": retry}

    @staticmethod
    def _is_known(collection, value) -> bool:
        """Return whether *value* is in *collection*, tolerating unhashable values."""
        try:
            return value in collection
        except TypeError:
            # An unhashable value (e.g. a list) can never be a member of a
            # hash-based collection, so it is simply "not known".
            return False

    def _check_phase(self, phase: str, tool_name: str) -> Optional[dict]:
        """Gate 1: reject tools that are not whitelisted for *phase*."""
        allowed = PHASE_TOOL_WHITELIST.get(phase, set())
        if tool_name in allowed:
            return None
        return self._reject(
            PHASE_VIOLATION,
            f"Tool '{tool_name}' is not allowed in phase '{phase}'. "
            f"Allowed tools: {sorted(allowed)}",
            False,
        )

    def _check_schema(self, tool_name: str, arguments) -> Optional[dict]:
        """Gate 2: reject calls with missing or malformed required arguments."""
        for arg in _REQUIRED_ARGS.get(tool_name, ()):
            # A key explicitly set to None carries no usable value, so it is
            # reported the same way as an absent key.
            if arguments.get(arg) is None:
                return self._reject(
                    INVALID_PARAMS,
                    f"Missing required argument '{arg}' for tool '{tool_name}'",
                    True,
                )

        if tool_name == "correlate_alerts":
            ids = arguments.get("alert_ids", [])
            # Correlation needs at least two alerts; a bare string is not an
            # acceptable sequence of ids.
            if not isinstance(ids, (list, tuple)) or len(ids) < 2:
                return self._reject(
                    INVALID_PARAMS,
                    "correlate_alerts requires at least 2 alert_ids",
                    True,
                )
        return None

    def _check_grounding(
        self, tool_name: str, arguments, graph: "ThreatGraph"
    ) -> Optional[dict]:
        """Gate 3: reject actions whose target is absent from the threat graph."""
        if tool_name == "block_ioc":
            if not self._is_known(graph.iocs, arguments["ioc_value"]):
                return self._reject(
                    UNGROUNDED_ACTION,
                    "IOC not in Threat Graph. Run investigation first.",
                    True,
                )
        elif tool_name == "kill_process":
            # Processes are looked up under the composite key
            # "<hostname>:<process_name>".
            key = f"{arguments['hostname']}:{arguments['process_name']}"
            if key not in graph.processes:
                return self._reject(
                    UNGROUNDED_ACTION,
                    "Process not in Threat Graph.",
                    True,
                )
        elif tool_name == "enrich_ioc":
            if not self._is_known(graph.iocs, arguments["ioc_value"]):
                return self._reject(
                    UNGROUNDED_ACTION,
                    "IOC not known. Discover it during investigation first.",
                    True,
                )
        return None
|
|