"""Tests for Task 6 — Action Validation Middleware.

Each test calls ``ActionValidationMiddleware.validate(phase, tool, params,
graph)`` and checks either that it returns ``None`` (action accepted) or that
it returns an error dict whose ``"error"`` code names the gate that rejected
the action:

* Gate 1 — ``PHASE_VIOLATION``: the tool is not allowed in the given phase;
  the ``"message"`` is expected to mention the tools that *are* allowed.
* Gate 2 — ``INVALID_PARAMS``: required parameters are missing or
  insufficient (e.g. no ``ioc_value``, or fewer than two ``alert_ids``).
* Gate 3 — ``UNGROUNDED_ACTION``: the action targets an IOC or process that
  is not present in the supplied threat graph.

Error dicts also carry a boolean ``"retry"`` flag, checked at the bottom of
this module.
"""
import os
import sys

# Put the project root on sys.path so ``server.*`` resolves even when pytest
# is launched from a directory other than the project root.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from server.action_validation import (  # noqa: E402  (needs sys.path tweak above)
    ActionValidationMiddleware,
    PHASE_VIOLATION,
    INVALID_PARAMS,
    UNGROUNDED_ACTION,
)
from server.threat_graph import ThreatGraph, IOCNode, ProcessNode  # noqa: E402


# ---------------------------------------------------------------------------
# Fixture helpers
# ---------------------------------------------------------------------------


def _empty():
    """Return a threat graph with no nodes, so no action is grounded."""
    return ThreatGraph()


def _graph_with_ioc(value="1.2.3.4"):
    """Return a threat graph holding a single IP-type IOC node for *value*."""
    g = ThreatGraph()
    g.add_ioc(IOCNode(ioc_value=value, ioc_type="ip", confidence=0.9))
    return g


def _graph_with_process(host="H", proc_name="evil.exe"):
    """Return a threat graph holding one process *proc_name* on *host*.

    The node id follows the ``"<host>:<process_name>"`` form used here.
    """
    g = ThreatGraph()
    g.add_process(
        ProcessNode(
            process_id=f"{host}:{proc_name}",
            hostname=host,
            process_name=proc_name,
        )
    )
    return g


def _assert_rejected(err, code):
    """Assert that *err* is an error dict carrying error code *code*.

    The two checks are kept separate so a failure shows whether validation
    unexpectedly passed (``None``) or rejected with the wrong code.
    Returns *err* for further inspection.
    """
    assert err is not None, f"expected a {code} rejection, got None"
    assert err["error"] == code, err
    return err


# ---------------------------------------------------------------------------
# Gate 1 — phase check
# ---------------------------------------------------------------------------


def test_gate1_rejects_wrong_phase():
    m = ActionValidationMiddleware()
    err = m.validate("triage", "kill_process", {}, _empty())
    _assert_rejected(err, PHASE_VIOLATION)


def test_gate1_passes_correct_phase():
    m = ActionValidationMiddleware()
    err = m.validate(
        "remediation",
        "kill_process",
        {"hostname": "H", "process_name": "evil.exe"},
        _graph_with_process(),
    )
    assert err is None


def test_gate1_error_lists_allowed_tools():
    m = ActionValidationMiddleware()
    err = m.validate("triage", "kill_process", {}, _empty())
    _assert_rejected(err, PHASE_VIOLATION)
    msg = err["message"].lower()
    # At least one triage-phase tool should be suggested to the caller.
    assert any(t in msg for t in ["read_alerts", "read_topology", "correlate_alerts"])


# ---------------------------------------------------------------------------
# Gate 2 — parameter validation
# ---------------------------------------------------------------------------


def test_gate2_rejects_missing_ioc_value():
    m = ActionValidationMiddleware()
    err = m.validate("remediation", "block_ioc", {}, _empty())
    _assert_rejected(err, INVALID_PARAMS)


def test_gate2_rejects_correlate_with_one_alert():
    m = ActionValidationMiddleware()
    err = m.validate("triage", "correlate_alerts", {"alert_ids": ["A1"]}, _empty())
    _assert_rejected(err, INVALID_PARAMS)


# ---------------------------------------------------------------------------
# Gate 3 — grounding in the threat graph
# ---------------------------------------------------------------------------


def test_gate3_rejects_ungrounded_block():
    m = ActionValidationMiddleware()
    err = m.validate("remediation", "block_ioc", {"ioc_value": "1.2.3.4"}, _empty())
    _assert_rejected(err, UNGROUNDED_ACTION)


def test_gate3_passes_grounded_block():
    m = ActionValidationMiddleware()
    err = m.validate(
        "remediation", "block_ioc", {"ioc_value": "1.2.3.4"}, _graph_with_ioc()
    )
    assert err is None


def test_gate3_rejects_ungrounded_kill():
    m = ActionValidationMiddleware()
    err = m.validate(
        "remediation",
        "kill_process",
        {"hostname": "H", "process_name": "unknown.exe"},
        _empty(),
    )
    _assert_rejected(err, UNGROUNDED_ACTION)


# ---------------------------------------------------------------------------
# End-to-end and retry semantics
# ---------------------------------------------------------------------------


def test_all_gates_pass_returns_none():
    # Well-formed, grounded remediation actions must clear every gate. Both
    # remediation tools are covered; fresh middleware instances are used so
    # the test does not assume validate() is stateless across calls.
    block_err = ActionValidationMiddleware().validate(
        "remediation", "block_ioc", {"ioc_value": "1.2.3.4"}, _graph_with_ioc()
    )
    assert block_err is None

    kill_err = ActionValidationMiddleware().validate(
        "remediation",
        "kill_process",
        {"hostname": "H", "process_name": "evil.exe"},
        _graph_with_process(),
    )
    assert kill_err is None


def test_retry_flag_false_for_phase_violation():
    m = ActionValidationMiddleware()
    err = m.validate("triage", "kill_process", {}, _empty())
    # Guard first: indexing None would raise TypeError instead of failing.
    _assert_rejected(err, PHASE_VIOLATION)
    assert err["retry"] is False


def test_retry_flag_true_for_invalid_params():
    m = ActionValidationMiddleware()
    err = m.validate("remediation", "block_ioc", {}, _empty())
    # Guard first: indexing None would raise TypeError instead of failing.
    _assert_rejected(err, INVALID_PARAMS)
    assert err["retry"] is True