File size: 3,818 Bytes
abd4352 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | """
tests/unit/test_sql_sandbox.py
Tests for sqlglot-based SQL safety validator.
All tests are pure Python — no external calls.
"""
import pytest
from sandbox.sql_sandbox import validate_sql
# ── Safe queries (must pass) ──────────────────────────────────────────────────
# Read-only statements used to parametrize test_safe_queries_pass.
SAFE_QUERIES = [
    # plain projection with a row limit
    "SELECT * FROM orders LIMIT 10",
    # aggregation with GROUP BY
    "SELECT product, SUM(amount) FROM orders GROUP BY product",
    # two-table join with aliases
    "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
    # common table expression wrapping a SELECT
    "WITH cte AS (SELECT * FROM orders) SELECT * FROM cte",
    # filtered count
    "SELECT COUNT(*) FROM orders WHERE created_at > '2024-01-01'",
    # DISTINCT with ordering
    "SELECT DISTINCT region FROM customers ORDER BY region",
    # window function
    "SELECT product, AVG(amount) OVER (PARTITION BY region) FROM orders",
    # type cast
    "SELECT CAST(amount AS TEXT) FROM orders",
]
@pytest.mark.unit
@pytest.mark.parametrize("sql", SAFE_QUERIES)
def test_safe_queries_pass(sql):
    """Each entry of SAFE_QUERIES must validate and report an empty error."""
    accepted, err = validate_sql(sql)

    # Surface the validator's message when a supposedly safe query fails.
    assert accepted, f"Expected safe query to pass but got: {err}"
    assert err == ""
# ── Blocked write operations ──────────────────────────────────────────────────
# (statement, label) pairs used to parametrize test_blocked_write_operations.
# The label names the kind of statement each entry represents.
BLOCKED_QUERIES = [
    # data-modifying statements
    ("DELETE FROM orders WHERE 1=1", "Delete"),
    ("DROP TABLE orders", "Drop"),
    ("UPDATE orders SET amount = 0", "Update"),
    ("INSERT INTO orders (id) VALUES (1)", "Insert"),
    # schema-changing statements
    ("ALTER TABLE orders ADD COLUMN foo TEXT", "AlterTable"),
    ("CREATE TABLE evil (id INT)", "Create"),
    # privilege change
    ("GRANT ALL ON orders TO attacker", "Grant"),
]
@pytest.mark.unit
@pytest.mark.parametrize("sql,expected_type", BLOCKED_QUERIES)
def test_blocked_write_operations(sql, expected_type):
    """Every statement in BLOCKED_QUERIES is rejected with a non-empty error.

    ``expected_type`` comes from the parametrization but is not asserted on
    here; only the rejection and the presence of a message are checked.
    """
    allowed, reason = validate_sql(sql)

    assert not allowed, f"Expected '{sql}' to be blocked"
    assert reason != ""
@pytest.mark.unit
def test_truncate_blocked():
    """TRUNCATE may be parsed as its own node type by newer sqlglot versions."""
    # The error text is not inspected: TRUNCATE must be rejected however
    # sqlglot happens to classify it.
    allowed, _ = validate_sql("TRUNCATE orders")
    assert not allowed, "TRUNCATE should be blocked"
# ── Blocked dangerous functions ───────────────────────────────────────────────
@pytest.mark.unit
def test_blocks_pg_read_file():
    """A query calling pg_read_file must be rejected with an error message."""
    allowed, reason = validate_sql("SELECT pg_read_file('/etc/passwd')")

    assert not allowed
    # NOTE(review): because of the ``or`` clause, any non-empty message
    # satisfies this assertion, whether or not it mentions the function name.
    assert "pg_read_file" in reason.lower() or reason != ""
# ── Edge cases ────────────────────────────────────────────────────────────────
@pytest.mark.unit
def test_empty_query_rejected():
    """An empty string is rejected and comes with a non-empty error message."""
    empty_sql = ""
    allowed, reason = validate_sql(empty_sql)

    assert not allowed
    assert reason != ""
@pytest.mark.unit
def test_whitespace_only_rejected():
    """Input made only of spaces, a newline and a tab is rejected."""
    blank_sql = " \n\t "
    allowed, _ = validate_sql(blank_sql)
    assert not allowed
@pytest.mark.unit
def test_sqlite_dialect():
    """A strftime() SELECT validates when the dialect is given as "sqlite"."""
    query = "SELECT strftime('%Y', created_at) FROM orders"
    accepted, err = validate_sql(query, dialect="sqlite")
    assert accepted, f"SQLite date function should pass: {err}"
@pytest.mark.unit
def test_cte_with_embedded_delete_blocked():
    """DELETE inside a CTE must still be blocked."""
    # The outer statement is a SELECT; the write is hidden in the CTE body.
    sql = (
        "\n"
        "WITH bad AS (DELETE FROM orders RETURNING *)\n"
        "SELECT * FROM bad\n"
    )
    allowed, _ = validate_sql(sql)
    assert not allowed
@pytest.mark.unit
def test_multiple_safe_statements_pass():
    """A single SELECT is safe.

    Despite the test's name, only the one-statement query ``SELECT 1`` is
    exercised here.
    """
    accepted, _ = validate_sql("SELECT 1")
    assert accepted
@pytest.mark.unit
def test_syntax_error_returns_false():
    """Malformed SQL is reported as a failed validation with an error message."""
    malformed_sql = "SELECT FROM WHERE"
    allowed, reason = validate_sql(malformed_sql)

    assert not allowed
    # NOTE(review): the ``or`` clause lets any non-empty message pass, so the
    # word "parse" is not actually required to appear.
    assert "parse" in reason.lower() or reason != ""
|