| """ |
| tests/unit/test_sql_sandbox.py |
| Tests for the sqlglot-based SQL safety validator. |
| All tests are pure Python — no external calls. |
| """ |
|
|
| import pytest |
| from sandbox.sql_sandbox import validate_sql |
|
|
|
|
| |
|
|
# Read-only queries the validator must accept. They cover plain selects,
# aggregates, joins, CTEs, filters, DISTINCT/ORDER BY, window functions and casts.
SAFE_QUERIES = [
    # Basic projection with a row limit.
    "SELECT * FROM orders LIMIT 10",
    # Aggregation with GROUP BY.
    "SELECT product, SUM(amount) FROM orders GROUP BY product",
    # Two-table join (implicit concatenation keeps the exact original text).
    (
        "SELECT o.order_id, c.name FROM orders o "
        "JOIN customers c ON o.customer_id = c.id"
    ),
    # Common table expression.
    "WITH cte AS (SELECT * FROM orders) SELECT * FROM cte",
    # WHERE filter against a date literal.
    "SELECT COUNT(*) FROM orders WHERE created_at > '2024-01-01'",
    # DISTINCT combined with ORDER BY.
    "SELECT DISTINCT region FROM customers ORDER BY region",
    # Window function.
    "SELECT product, AVG(amount) OVER (PARTITION BY region) FROM orders",
    # Type cast.
    "SELECT CAST(amount AS TEXT) FROM orders",
]
|
|
@pytest.mark.unit
@pytest.mark.parametrize("sql", SAFE_QUERIES)
def test_safe_queries_pass(sql):
    """Each read-only query in SAFE_QUERIES is accepted and reports no error."""
    verdict, message = validate_sql(sql)
    assert verdict, f"Expected safe query to pass but got: {message}"
    assert message == ""
|
|
|
|
| |
|
|
# (sql, expected_type) pairs the validator must reject. The type label appears
# to be a sqlglot expression class name. It is used as a test parameter only.
BLOCKED_QUERIES = [
    ("DELETE FROM orders WHERE 1=1", "Delete"),           # DML: row removal
    ("DROP TABLE orders", "Drop"),                        # DDL: object removal
    ("UPDATE orders SET amount = 0", "Update"),           # DML: row mutation
    ("INSERT INTO orders (id) VALUES (1)", "Insert"),     # DML: row creation
    ("ALTER TABLE orders ADD COLUMN foo TEXT", "AlterTable"),  # DDL: schema change
    ("CREATE TABLE evil (id INT)", "Create"),              # DDL: object creation
    ("GRANT ALL ON orders TO attacker", "Grant"),         # DCL: privilege change
]
|
|
@pytest.mark.unit
@pytest.mark.parametrize("sql,expected_type", BLOCKED_QUERIES)
def test_blocked_write_operations(sql, expected_type):
    """Each write/DDL/DCL statement in BLOCKED_QUERIES must be rejected.

    ``expected_type`` labels the statement kind a case exercises. It is used
    only to make failure messages identify the offending kind. The error text
    is not required to contain it, because these tests do not pin the
    validator's message wording.
    """
    ok, err = validate_sql(sql)
    assert not ok, f"Expected {expected_type} statement {sql!r} to be blocked"
    assert err != "", f"Blocked {expected_type} statement {sql!r} returned an empty error"
|
|
@pytest.mark.unit
def test_truncate_blocked():
    """TRUNCATE must be blocked with a non-empty error.

    TRUNCATE may be parsed as its own node type by newer sqlglot versions. It is
    kept out of BLOCKED_QUERIES and checked on its own here.
    """
    ok, err = validate_sql("TRUNCATE orders")
    assert not ok, "TRUNCATE should be blocked"
    # Match the other blocked-statement tests: a rejection must carry a reason.
    assert err != "", "Blocked TRUNCATE returned an empty error"
|
|
|
|
| |
|
|
@pytest.mark.unit
def test_blocks_pg_read_file():
    """A SELECT that calls ``pg_read_file`` must be rejected with a reason."""
    sql = "SELECT pg_read_file('/etc/passwd')"
    ok, err = validate_sql(sql)
    assert not ok, f"Expected {sql!r} to be blocked"
    # Only a non-empty error is required. The wording (e.g. whether it names
    # the function) is deliberately not pinned.
    assert err != "", "Blocked pg_read_file call returned an empty error"
|
|
|
|
| |
|
|
@pytest.mark.unit
def test_empty_query_rejected():
    """An empty string is rejected, and the validator explains why."""
    accepted, reason = validate_sql("")
    assert not accepted
    assert reason != ""
|
|
|
|
@pytest.mark.unit
def test_whitespace_only_rejected():
    """A query made only of whitespace is rejected with a non-empty error.

    This mirrors test_empty_query_rejected, which asserts the same two conditions.
    """
    ok, err = validate_sql(" \n\t ")
    assert not ok, "Whitespace-only query should be rejected"
    assert err != "", "Rejected whitespace-only query returned an empty error"
|
|
|
|
@pytest.mark.unit
def test_sqlite_dialect():
    """SQLite's ``strftime`` date function is accepted when dialect="sqlite"."""
    query = "SELECT strftime('%Y', created_at) FROM orders"
    accepted, reason = validate_sql(query, dialect="sqlite")
    assert accepted, f"SQLite date function should pass: {reason}"
|
|
|
|
@pytest.mark.unit
def test_cte_with_embedded_delete_blocked():
    """A DELETE nested inside a CTE must still be blocked.

    The outer statement is a SELECT. A check that looked only at the top-level
    statement type would therefore accept this query.
    """
    # NOTE(review): data-modifying CTEs with RETURNING are PostgreSQL syntax. If
    # the validator's default dialect cannot parse them, this test passes via a
    # parse error rather than via DELETE detection. Confirm which path is taken.
    sql = """
        WITH bad AS (DELETE FROM orders RETURNING *)
        SELECT * FROM bad
    """
    ok, err = validate_sql(sql)
    assert not ok, "DELETE inside a CTE should be blocked"
    assert err != "", "Blocked CTE-embedded DELETE returned an empty error"
|
|
|
|
@pytest.mark.unit
def test_multiple_safe_statements_pass():
    """A single trivial SELECT is accepted with an empty error.

    Despite the test name, only one statement is submitted. Acceptance of
    several semicolon-separated statements is not exercised here.
    """
    ok, err = validate_sql("SELECT 1")
    assert ok, f"Expected 'SELECT 1' to pass but got: {err}"
    # Same success contract as test_safe_queries_pass: no error text on accept.
    assert err == ""
|
|
|
|
@pytest.mark.unit
def test_syntax_error_returns_false():
    """Unparseable SQL is rejected with a non-empty error instead of raising."""
    ok, err = validate_sql("SELECT FROM WHERE")
    assert not ok, "Malformed SQL should be rejected"
    # Only a non-empty error is required. The wording (e.g. whether it says
    # "parse") is deliberately not pinned.
    assert err != "", "Rejected malformed SQL returned an empty error"
|
|