| """Tests for Task 6 — Action Validation Middleware.""" |
|
|
| import os |
| import sys |
|
|
# Put the parent of this file's directory at the front of sys.path (once),
# presumably so the `server` package imported below resolves regardless of
# the directory the tests are launched from.
_PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
|
| from server.action_validation import ( |
| ActionValidationMiddleware, |
| PHASE_VIOLATION, |
| INVALID_PARAMS, |
| UNGROUNDED_ACTION, |
| ) |
| from server.threat_graph import ThreatGraph, IOCNode, ProcessNode |
|
|
|
|
def _empty():
    """Return a freshly constructed ThreatGraph with nothing added to it."""
    graph = ThreatGraph()
    return graph
|
|
|
|
def _graph_with_ioc(value="1.2.3.4"):
    """Return a new ThreatGraph to which a single IOC node has been added.

    The node is created with ``ioc_value=value``, ``ioc_type="ip"`` and
    ``confidence=0.9``.
    """
    graph = ThreatGraph()
    node = IOCNode(
        ioc_value=value,
        ioc_type="ip",
        confidence=0.9,
    )
    graph.add_ioc(node)
    return graph
|
|
|
|
def _graph_with_process(host="H", proc_name="evil.exe"):
    """Return a new ThreatGraph to which a single process node has been added.

    The node's ``process_id`` is ``"<host>:<proc_name>"``; ``hostname`` and
    ``process_name`` are passed through unchanged.
    """
    graph = ThreatGraph()
    node = ProcessNode(
        process_id=f"{host}:{proc_name}",
        hostname=host,
        process_name=proc_name,
    )
    graph.add_process(node)
    return graph
|
|
|
|
def test_gate1_rejects_wrong_phase():
    """``kill_process`` requested in the ``triage`` phase yields PHASE_VIOLATION."""
    middleware = ActionValidationMiddleware()
    graph = _empty()
    result = middleware.validate("triage", "kill_process", {}, graph)
    assert result is not None
    error_code = result["error"]
    assert error_code == PHASE_VIOLATION
|
|
|
|
def test_gate1_passes_correct_phase():
    """``kill_process`` in ``remediation`` validates cleanly (returns None).

    The params name the same host/process that ``_graph_with_process`` adds
    by default.
    """
    middleware = ActionValidationMiddleware()
    params = {"hostname": "H", "process_name": "evil.exe"}
    graph = _graph_with_process()
    result = middleware.validate("remediation", "kill_process", params, graph)
    assert result is None
|
|
|
|
def test_gate1_error_lists_allowed_tools():
    """The rejection message names at least one of the expected tool names.

    The comparison is done against the lower-cased ``message`` field.
    """
    middleware = ActionValidationMiddleware()
    result = middleware.validate("triage", "kill_process", {}, _empty())
    assert result is not None
    lowered = result["message"].lower()
    candidates = ("read_alerts", "read_topology", "correlate_alerts")
    # One match is enough; the test does not require every tool to be listed.
    mentioned = [tool for tool in candidates if tool in lowered]
    assert mentioned
|
|
|
|
def test_gate2_rejects_missing_ioc_value():
    """``block_ioc`` called with empty params yields INVALID_PARAMS."""
    middleware = ActionValidationMiddleware()
    no_params = {}
    result = middleware.validate("remediation", "block_ioc", no_params, _empty())
    assert result is not None
    assert result["error"] == INVALID_PARAMS
|
|
|
|
def test_gate2_rejects_correlate_with_one_alert():
    """``correlate_alerts`` given a single alert id yields INVALID_PARAMS."""
    middleware = ActionValidationMiddleware()
    params = {"alert_ids": ["A1"]}
    result = middleware.validate("triage", "correlate_alerts", params, _empty())
    assert result is not None
    assert result["error"] == INVALID_PARAMS
|
|
|
|
def test_gate3_rejects_ungrounded_block():
    """``block_ioc`` for a value absent from an empty graph yields UNGROUNDED_ACTION."""
    middleware = ActionValidationMiddleware()
    params = {"ioc_value": "1.2.3.4"}
    result = middleware.validate("remediation", "block_ioc", params, _empty())
    assert result is not None
    assert result["error"] == UNGROUNDED_ACTION
|
|
|
|
def test_gate3_passes_grounded_block():
    """``block_ioc`` validates cleanly when the graph already holds that IOC value."""
    middleware = ActionValidationMiddleware()
    params = {"ioc_value": "1.2.3.4"}
    graph = _graph_with_ioc()
    result = middleware.validate("remediation", "block_ioc", params, graph)
    assert result is None
|
|
|
|
def test_gate3_rejects_ungrounded_kill():
    """``kill_process`` for a process absent from an empty graph yields UNGROUNDED_ACTION."""
    middleware = ActionValidationMiddleware()
    params = {"hostname": "H", "process_name": "unknown.exe"}
    result = middleware.validate("remediation", "kill_process", params, _empty())
    assert result is not None
    assert result["error"] == UNGROUNDED_ACTION
|
|
|
|
def test_all_gates_pass_returns_none():
    """A ``block_ioc`` call that is accepted returns ``None`` rather than an error dict."""
    middleware = ActionValidationMiddleware()
    params = {"ioc_value": "1.2.3.4"}
    graph = _graph_with_ioc()
    result = middleware.validate("remediation", "block_ioc", params, graph)
    assert result is None
|
|
|
|
def test_retry_flag_false_for_phase_violation():
    """A phase-violation error must carry ``retry`` set to ``False``.

    Requests ``kill_process`` during ``triage`` with empty params against an
    empty graph and inspects the returned error dict.
    """
    m = ActionValidationMiddleware()
    err = m.validate("triage", "kill_process", {}, _empty())
    # Guard before subscripting, as the sibling tests do: if validation
    # unexpectedly passes (None), fail with an assertion, not a TypeError.
    assert err is not None
    # Confirm the premise in the test name before checking the flag.
    assert err["error"] == PHASE_VIOLATION
    assert err["retry"] is False
|
|
|
|
def test_retry_flag_true_for_invalid_params():
    """An invalid-params error must carry ``retry`` set to ``True``.

    Requests ``block_ioc`` during ``remediation`` with empty params against
    an empty graph and inspects the returned error dict.
    """
    m = ActionValidationMiddleware()
    err = m.validate("remediation", "block_ioc", {}, _empty())
    # Guard before subscripting, as the sibling tests do: if validation
    # unexpectedly passes (None), fail with an assertion, not a TypeError.
    assert err is not None
    # Confirm the premise in the test name before checking the flag.
    assert err["error"] == INVALID_PARAMS
    assert err["retry"] is True
|
|