# nl2sql-copilot / tests / test_safety.py
# Commit: Melika Kheirieh, "style: format code with ruff" (c1bc4eb)
# Viewer links: raw, history, blame; file size: 2.9 kB
from nl2sql.safety import Safety
import pytest
def test_safety_allows_select():
    """A plain ``SELECT`` passes the check and yields SQL data plus a safety trace."""
    checker = Safety()
    outcome = checker.check("SELECT * FROM users;")

    # The check must succeed before the payload and trace are inspected.
    assert outcome.ok
    assert "sql" in outcome.data
    assert outcome.trace.stage == "safety"
def test_safety_allows_with_select_cte():
    """A read-only CTE (``WITH ... SELECT``) is accepted by the safety check."""
    # Built line by line; the resulting string starts and ends with a newline.
    cte_query = (
        "\n"
        "WITH recent AS (\n"
        "SELECT id FROM users WHERE created_at > '2024-01-01'\n"
        ")\n"
        "SELECT * FROM users u JOIN recent r ON u.id = r.id;\n"
    )
    assert Safety().check(cte_query).ok
def test_safety_allows_select_with_comments_and_newlines():
    """Block comments, line comments and newlines around a SELECT are tolerated."""
    commented_query = "/* head */ \n -- inline\n SELECT 1; -- tail"
    verdict = Safety().check(commented_query)
    assert verdict.ok
def test_safety_allows_keywords_inside_string_literals():
    """Forbidden keywords that appear only inside string literals do not fail the check."""
    literal_query = "SELECT 'DROP TABLE x' as note, 'delete from y' as text;"
    verdict = Safety().check(literal_query)
    # Surface the reported error in the failure message to ease debugging.
    assert verdict.ok, verdict.error
def test_safety_blocks_delete():
    """A bare ``DELETE`` is rejected with a 'Forbidden' or 'Non-SELECT' error entry."""
    outcome = Safety().check("DELETE FROM users;")
    assert not outcome.ok

    # A missing/empty error is treated as "no messages", which fails the check below.
    messages = outcome.error or []
    assert any("Forbidden" in msg or "Non-SELECT" in msg for msg in messages)
@pytest.mark.parametrize(
    "sql",
    [
        "UPDATE users SET name='X' WHERE id=1;",
        "INSERT INTO users(id) VALUES (1);",
        "DROP TABLE users;",
        "CREATE TABLE x(id INT);",
        "ALTER TABLE users ADD COLUMN x INT;",
        "ATTACH DATABASE 'hack.db' AS h;",
        "PRAGMA journal_mode=WAL;",
    ],
)
def test_safety_blocks_forbidden_statements(sql):
    """Each write, DDL, ATTACH or PRAGMA statement in the list must be rejected."""
    assert not Safety().check(sql).ok
def test_safety_blocks_stacked_delete_after_select():
    """A ``DELETE`` stacked after a valid ``SELECT`` makes the whole input fail."""
    stacked_query = "SELECT * FROM users; DELETE FROM users;"
    verdict = Safety().check(stacked_query)
    assert not verdict.ok
def test_safety_blocks_stacked_delete_with_spaces():
    """Whitespace and a newline around the separator do not hide a stacked DELETE."""
    stacked_query = "SELECT * FROM users ; \n DELETE users;"
    verdict = Safety().check(stacked_query)
    assert not verdict.ok
def test_safety_blocks_delete_inside_cte():
    """A ``DELETE`` placed inside a CTE body is rejected even though a SELECT follows."""
    # Built line by line; the resulting string starts and ends with a newline.
    cte_query = (
        "\n"
        "WITH bad AS (DELETE FROM users)\n"
        "SELECT * FROM users;\n"
    )
    assert not Safety().check(cte_query).ok
@pytest.mark.parametrize(
    "sql",
    [
        "/*D*/ROP TABLE users;",
        "PR/*x*/AGMA journal_mode=WAL;",
        "AL/* comment */TER TABLE x ADD COLUMN y INT;",
    ],
)
def test_safety_blocks_comment_obfuscation(sql):
    """Statements with block comments spliced into or before the keyword must fail."""
    assert not Safety().check(sql).ok
@pytest.mark.parametrize(
    "sql",
    [
        "pragma journal_mode=WAL;",  # lower-case
        " PRAGMA user_version = 5 ; ",
        "\nATTACH DATABASE 'hack.db' AS h;",
    ],
)
def test_safety_blocks_forbidden_case_and_spacing(sql):
    """Lower-case keywords and leading/trailing whitespace do not bypass the check."""
    assert not Safety().check(sql).ok
def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
    """Contrast a trailing comment with a real second statement after ``SELECT 1;``.

    A line comment following the terminator must still pass, while a ``DROP``
    separated from the SELECT by a block comment must be rejected.
    """
    checker = Safety()
    comment_only_tail = "SELECT 1; -- now do something bad\n"
    drop_after_spacer = "SELECT 1; /* spacer */ DROP TABLE x;"

    # Both inputs go through the same Safety instance, in this order.
    assert checker.check(comment_only_tail).ok
    assert not checker.check(drop_after_spacer).ok