File size: 3,744 Bytes
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7211e63
 
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7211e63
 
 
57e71f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""3-gate action validation middleware: phase whitelist + schema + graph groundedness."""

from __future__ import annotations

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from .threat_graph import ThreatGraph


PHASE_TOOL_WHITELIST = {
    "triage":        {"read_alerts", "read_topology", "correlate_alerts"},
    "investigation": {"query_host", "run_forensics", "add_ioc",
                      "enrich_ioc", "scan_host_vulnerabilities"},
    "remediation":   {"block_ioc", "kill_process", "isolate_segment",
                      "terminate_pid", "create_firewall_rule", "quarantine_file",
                      "request_human_approval"},
    "report":        {"submit_containment_plan"},
}

PHASE_VIOLATION = "PHASE_VIOLATION"
INVALID_PARAMS = "INVALID_PARAMS"
UNGROUNDED_ACTION = "UNGROUNDED_ACTION"


_REQUIRED_ARGS = {
    "block_ioc":                 ["ioc_value"],
    "kill_process":              ["hostname", "process_name"],
    "isolate_segment":           ["target"],
    "correlate_alerts":          ["alert_ids"],
    "enrich_ioc":                ["ioc_value", "ioc_type"],
    "scan_host_vulnerabilities": ["hostname"],
    "terminate_pid":             ["hostname", "pid"],
    "create_firewall_rule":      ["hostname", "target_ip", "action"],
    "quarantine_file":           ["hostname", "file_path"],
}


class ActionValidationMiddleware:
    """Three-gate validator for agent tool calls.

    ``validate`` runs the gates in order and returns the error dict of the
    first gate that fails, or ``None`` when the call is allowed:

    1. Phase whitelist -- ``tool_name`` must appear in
       ``PHASE_TOOL_WHITELIST[phase]``; an unknown phase allows nothing.
    2. Schema -- every argument listed in ``_REQUIRED_ARGS[tool_name]`` must be
       present, and ``correlate_alerts`` additionally needs a list/tuple of at
       least two ``alert_ids``.
    3. Groundedness -- ``block_ioc`` / ``enrich_ioc`` must name an IOC that is
       in ``graph.iocs``; ``kill_process`` must name a
       ``"<hostname>:<process_name>"`` key that is in ``graph.processes``.

    Error dicts carry ``error`` (``PHASE_VIOLATION``, ``INVALID_PARAMS`` or
    ``UNGROUNDED_ACTION``), a human-readable ``message`` and a ``retry`` flag.
    """

    def validate(
        self,
        phase: str,
        tool_name: str,
        arguments: Optional[dict],
        graph: "ThreatGraph",
    ) -> Optional[dict]:
        """Check one tool call against all three gates.

        Args:
            phase: Current workflow phase (a key of ``PHASE_TOOL_WHITELIST``).
            tool_name: Tool the agent wants to invoke.
            arguments: Tool arguments. ``None`` is treated as an empty mapping,
                so missing required arguments are reported instead of raising.
            graph: Threat Graph whose ``iocs`` and ``processes`` collections
                are consulted by gate 3.

        Returns:
            ``None`` if every gate passes, otherwise the first failing gate's
            error dict.
        """
        if arguments is None:
            arguments = {}

        # Gates must run strictly in order: gate 3 indexes `arguments` directly
        # and relies on gate 2 having confirmed the required keys exist.
        error = self._check_phase(phase, tool_name)
        if error is None:
            error = self._check_schema(tool_name, arguments)
        if error is None:
            error = self._check_grounded(tool_name, arguments, graph)
        return error

    def _check_phase(self, phase: str, tool_name: str) -> Optional[dict]:
        """Gate 1: reject tools not whitelisted for ``phase`` (not retryable)."""
        allowed = PHASE_TOOL_WHITELIST.get(phase, set())
        if tool_name in allowed:
            return None
        return {
            "error": PHASE_VIOLATION,
            "message": (
                f"Tool '{tool_name}' is not allowed in phase '{phase}'. "
                f"Allowed tools: {sorted(allowed)}"
            ),
            "retry": False,
        }

    def _check_schema(self, tool_name: str, arguments) -> Optional[dict]:
        """Gate 2: check argument container, required keys and alert_ids shape."""
        from collections.abc import Mapping

        required = _REQUIRED_ARGS.get(tool_name, [])
        # Only tools with a declared schema need a mapping: for a string the
        # `in` test below would be a substring match and for a list `.get` /
        # indexing would raise. Argument-less tools stay permissive as before.
        if required and not isinstance(arguments, Mapping):
            return {
                "error": INVALID_PARAMS,
                "message": (
                    f"Arguments for tool '{tool_name}' must be an object, "
                    f"got {type(arguments).__name__}"
                ),
                "retry": True,
            }
        for arg in required:
            if arg not in arguments:
                return {
                    "error": INVALID_PARAMS,
                    "message": f"Missing required argument '{arg}' for tool '{tool_name}'",
                    "retry": True,
                }
        if tool_name == "correlate_alerts":
            ids = arguments.get("alert_ids", [])
            if not isinstance(ids, (list, tuple)) or len(ids) < 2:
                return {
                    "error": INVALID_PARAMS,
                    "message": "correlate_alerts requires at least 2 alert_ids",
                    "retry": True,
                }
        return None

    @staticmethod
    def _contains(collection, item) -> bool:
        """Return ``item in collection``, treating a ``TypeError`` as "absent".

        A malformed argument such as a list where a string is expected is
        unhashable, so a lookup in a dict or set would raise. Failing closed
        (reporting the entity as ungrounded) keeps the middleware from crashing
        while still blocking the action.
        """
        try:
            return item in collection
        except TypeError:
            return False

    def _check_grounded(self, tool_name: str, arguments, graph) -> Optional[dict]:
        """Gate 3: require IOC/process arguments to exist in the Threat Graph."""
        if tool_name == "block_ioc":
            if not self._contains(graph.iocs, arguments["ioc_value"]):
                return {
                    "error": UNGROUNDED_ACTION,
                    "message": "IOC not in Threat Graph. Run investigation first.",
                    "retry": True,
                }
        elif tool_name == "kill_process":
            # Processes are keyed as "<hostname>:<process_name>".
            key = f"{arguments['hostname']}:{arguments['process_name']}"
            if not self._contains(graph.processes, key):
                return {
                    "error": UNGROUNDED_ACTION,
                    "message": "Process not in Threat Graph.",
                    "retry": True,
                }
        elif tool_name == "enrich_ioc":
            if not self._contains(graph.iocs, arguments["ioc_value"]):
                return {
                    "error": UNGROUNDED_ACTION,
                    "message": "IOC not known. Discover it during investigation first.",
                    "retry": True,
                }
        return None