File size: 3,818 Bytes
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
tests/unit/test_sql_sandbox.py
Tests for sqlglot-based SQL safety validator.
All tests are pure Python — no external calls.
"""

import pytest
from sandbox.sql_sandbox import validate_sql


# ── Safe queries (must pass) ──────────────────────────────────────────────────

# Read-only queries that validate_sql must accept. Together they cover:
# a plain SELECT with LIMIT, aggregation with GROUP BY, a JOIN with table
# aliases, a CTE (WITH ... SELECT), a WHERE filter on a string literal,
# DISTINCT with ORDER BY, a window function (OVER ... PARTITION BY), and CAST.
SAFE_QUERIES = [
    "SELECT * FROM orders LIMIT 10",
    "SELECT product, SUM(amount) FROM orders GROUP BY product",
    "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
    "WITH cte AS (SELECT * FROM orders) SELECT * FROM cte",
    "SELECT COUNT(*) FROM orders WHERE created_at > '2024-01-01'",
    "SELECT DISTINCT region FROM customers ORDER BY region",
    "SELECT product, AVG(amount) OVER (PARTITION BY region) FROM orders",
    "SELECT CAST(amount AS TEXT) FROM orders",
]

@pytest.mark.unit
@pytest.mark.parametrize("sql", SAFE_QUERIES)
def test_safe_queries_pass(sql):
    """Every read-only query in SAFE_QUERIES validates cleanly.

    A passing query must come back with ``ok`` truthy and an empty error
    string; the error is echoed in the failure message to aid debugging.
    """
    passed, message = validate_sql(sql)
    assert passed, f"Expected safe query to pass but got: {message}"
    assert message == ""


# ── Blocked write operations ──────────────────────────────────────────────────

# (sql, expected_type) pairs that validate_sql must reject. The statements
# write data (DELETE, UPDATE, INSERT), change schema (DROP, ALTER, CREATE) or
# change privileges (GRANT).
# expected_type names the statement kind. test_blocked_write_operations does
# not compare it against the validator's error message. It serves as a
# descriptive label, and it also appears in the generated pytest IDs.
# NOTE(review): the labels look like sqlglot expression class names, which can
# differ between sqlglot releases — confirm against the installed version
# before asserting on them.
BLOCKED_QUERIES = [
    ("DELETE FROM orders WHERE 1=1", "Delete"),
    ("DROP TABLE orders", "Drop"),
    ("UPDATE orders SET amount = 0", "Update"),
    ("INSERT INTO orders (id) VALUES (1)", "Insert"),
    ("ALTER TABLE orders ADD COLUMN foo TEXT", "AlterTable"),
    ("CREATE TABLE evil (id INT)", "Create"),
    ("GRANT ALL ON orders TO attacker", "Grant"),
]

@pytest.mark.unit
@pytest.mark.parametrize("sql,expected_type", BLOCKED_QUERIES)
def test_blocked_write_operations(sql, expected_type):
    """Each write/DDL/DCL statement in BLOCKED_QUERIES is rejected with a reason.

    ``expected_type`` is used to make failures self-describing. It is not
    matched against ``err``: nothing here pins the wording of the
    validator's error message, and the labels may not match the node class
    names of every sqlglot release.
    """
    ok, err = validate_sql(sql)
    assert not ok, f"Expected '{sql}' ({expected_type}) to be blocked"
    assert err != "", f"'{sql}' ({expected_type}) was blocked without an error message"

@pytest.mark.unit
def test_truncate_blocked():
    """TRUNCATE is rejected, whatever node type sqlglot parses it into.

    Newer sqlglot versions may parse TRUNCATE as its own node type, so this
    test checks only the verdict, not how the statement was classified.
    """
    allowed, _reason = validate_sql("TRUNCATE orders")
    assert not allowed, "TRUNCATE should be blocked"


# ── Blocked dangerous functions ───────────────────────────────────────────────

@pytest.mark.unit
def test_blocks_pg_read_file():
    """A SELECT that calls pg_read_file is rejected with a non-empty reason."""
    sql = "SELECT pg_read_file('/etc/passwd')"
    ok, err = validate_sql(sql)
    assert not ok, f"Expected '{sql}' to be blocked"
    # Only require that a reason is given. The message wording (including
    # whether it names the offending function) is not pinned by this test.
    assert err != "", "Blocked query should come with an error message"


# ── Edge cases ────────────────────────────────────────────────────────────────

@pytest.mark.unit
def test_empty_query_rejected():
    """An empty string is not a query: it must be refused with an explanation."""
    accepted, reason = validate_sql("")
    assert not accepted
    assert reason != ""


@pytest.mark.unit
def test_whitespace_only_rejected():
    """Input of only spaces, newlines and tabs is rejected with a reason.

    This mirrors test_empty_query_rejected, which also requires a non-empty
    error message.
    """
    ok, err = validate_sql("   \n\t  ")
    assert not ok
    assert err != "", "Whitespace-only input should come with an error message"


@pytest.mark.unit
def test_sqlite_dialect():
    """SQLite's strftime is accepted when validating with dialect="sqlite"."""
    query = "SELECT strftime('%Y', created_at) FROM orders"
    passed, message = validate_sql(query, dialect="sqlite")
    assert passed, f"SQLite date function should pass: {message}"


@pytest.mark.unit
def test_cte_with_embedded_delete_blocked():
    """A DELETE hidden inside a CTE still causes rejection."""
    # The outer statement is a SELECT, but the CTE body deletes rows, so the
    # validator must look below the top-level statement.
    # NOTE(review): if the default dialect cannot parse DELETE ... RETURNING
    # inside WITH, the rejection may come from a parse error rather than the
    # write check — confirm which path this test actually exercises.
    statement = """
    WITH bad AS (DELETE FROM orders RETURNING *)
    SELECT * FROM bad
    """
    verdict, _ = validate_sql(statement)
    assert not verdict


@pytest.mark.unit
def test_multiple_safe_statements_pass():
    """A lone ``SELECT 1`` is accepted.

    NOTE(review): despite the test name, only a single statement is checked
    here. This test does not cover how validate_sql treats several
    ``;``-separated statements.
    """
    accepted, _ = validate_sql("SELECT 1")
    assert accepted


@pytest.mark.unit
def test_syntax_error_returns_false():
    """Malformed SQL makes validate_sql return a falsy verdict with a reason."""
    ok, err = validate_sql("SELECT FROM WHERE")
    assert not ok, "Malformed SQL should be rejected"
    # Only require that a reason is given. Whether the message mentions
    # parsing is up to the validator and is not pinned here.
    assert err != "", "Rejected SQL should come with an error message"