File size: 2,793 Bytes
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from nl2sql.safety import Safety
import pytest



def test_safety_allows_select():
    """A plain SELECT passes the safety check.

    Also verifies that the result exposes an ``sql`` entry in ``data`` and
    that the trace is attributed to the ``safety`` stage.
    """
    s = Safety()
    result = s.check("SELECT * FROM users;")
    # Attach the reported error so a false rejection is diagnosable from
    # the failure output alone.
    assert result.ok, result.error
    assert "sql" in result.data
    assert result.trace.stage == "safety"

def test_safety_allows_with_select_cte():
    """A read-only ``WITH ... SELECT`` (CTE) query passes the safety check."""
    s = Safety()
    sql = """
    WITH recent AS (
      SELECT id FROM users WHERE created_at > '2024-01-01'
    )
    SELECT * FROM users u JOIN recent r ON u.id = r.id;
    """
    r = s.check(sql)
    # Attach the reported error so a false rejection is diagnosable from
    # the failure output alone.
    assert r.ok, r.error

def test_safety_allows_select_with_comments_and_newlines():
    """A SELECT surrounded by block comments, line comments and newlines passes."""
    s = Safety()
    sql = "/* head */ \n -- inline\n SELECT 1; -- tail"
    r = s.check(sql)
    # Attach the reported error so a false rejection is diagnosable from
    # the failure output alone.
    assert r.ok, r.error

def test_safety_allows_keywords_inside_string_literals():
    """Forbidden keywords that occur only inside string literals are accepted."""
    query = "SELECT 'DROP TABLE x' as note, 'delete from y' as text;"
    verdict = Safety().check(query)
    # The assertion message surfaces the reported error on failure.
    assert verdict.ok, verdict.error


def test_safety_blocks_delete():
    """A bare DELETE is rejected and the error mentions why.

    At least one item of ``error`` must contain either ``Forbidden`` or
    ``Non-SELECT``.
    """
    outcome = Safety().check("DELETE FROM users;")
    assert not outcome.ok

    expected_markers = ("Forbidden", "Non-SELECT")
    reported = outcome.error or []  # a falsy error is treated as "no items"
    assert any(marker in item for item in reported for marker in expected_markers)

@pytest.mark.parametrize(
    "sql",
    (
        "UPDATE users SET name='X' WHERE id=1;",
        "INSERT INTO users(id) VALUES (1);",
        "DROP TABLE users;",
        "CREATE TABLE x(id INT);",
        "ALTER TABLE users ADD COLUMN x INT;",
        "ATTACH DATABASE 'hack.db' AS h;",
        "PRAGMA journal_mode=WAL;",
    ),
)
def test_safety_blocks_forbidden_statements(sql):
    """Every non-SELECT statement in the parameter set fails the check."""
    verdict = Safety().check(sql)
    assert not verdict.ok

def test_safety_blocks_stacked_delete_after_select():
    """A DELETE stacked after a valid SELECT causes the whole input to be rejected."""
    stacked = "SELECT * FROM users; DELETE FROM users;"
    verdict = Safety().check(stacked)
    assert not verdict.ok

def test_safety_blocks_stacked_delete_with_spaces():
    """Whitespace and a newline around the ``;`` do not hide a stacked DELETE."""
    # NOTE(review): the second statement is `DELETE users` (no FROM) --
    # confirm the omission is intentional.
    stacked = "SELECT * FROM users ;   \n  DELETE users;"
    verdict = Safety().check(stacked)
    assert not verdict.ok

def test_safety_blocks_delete_inside_cte():
    """A DELETE placed in a CTE body is rejected even though the outer query is a SELECT."""
    checker = Safety()
    verdict = checker.check(
        """
    WITH bad AS (DELETE FROM users)
    SELECT * FROM users;
    """
    )
    assert not verdict.ok

@pytest.mark.parametrize(
    "sql",
    (
        # NOTE(review): here the `D` sits inside the comment, so removing the
        # comment leaves `ROP TABLE ...` rather than `DROP` -- confirm this
        # is the intended input.
        "/*D*/ROP TABLE users;",
        "PR/*x*/AGMA journal_mode=WAL;",
        "AL/* comment */TER TABLE x ADD COLUMN y INT;",
    ),
)
def test_safety_blocks_comment_obfuscation(sql):
    """Statements with block comments spliced into or before the keyword are rejected."""
    verdict = Safety().check(sql)
    assert not verdict.ok

@pytest.mark.parametrize(
    "sql",
    (
        "pragma journal_mode=WAL;",  # lower-case keyword
        "  PRAGMA  user_version = 5 ; ",
        "\nATTACH DATABASE 'hack.db' AS h;",
    ),
)
def test_safety_blocks_forbidden_case_and_spacing(sql):
    """Letter case and leading or extra whitespace do not let PRAGMA/ATTACH through."""
    verdict = Safety().check(sql)
    assert not verdict.ok

def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
    """A trailing comment is not a second statement, but a stacked DROP is.

    ``SELECT 1;`` followed only by a line comment must pass, while
    ``SELECT 1;`` followed by a block comment and a ``DROP TABLE`` must fail.
    """
    s = Safety()
    sql = "SELECT 1;  -- now do something bad\n"
    sql_bad = "SELECT 1;  /* spacer */  DROP TABLE x;"

    comment_only = s.check(sql)
    # Attach the reported error so a false rejection of the comment-only
    # input is diagnosable from the failure output alone.
    assert comment_only.ok, comment_only.error
    assert not s.check(sql_bad).ok