File size: 2,899 Bytes
570f7bd
 
 
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
c1bc4eb
 
 
 
 
 
 
 
 
 
 
 
 
570f7bd
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
 
 
 
c1bc4eb
 
 
 
 
 
 
 
 
570f7bd
 
 
 
 
c1bc4eb
 
 
 
 
 
 
 
 
570f7bd
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from nl2sql.safety import Safety
import pytest


def test_safety_allows_select():
    """A bare ``SELECT`` is accepted and its result is tagged with the safety stage."""
    outcome = Safety().check("SELECT * FROM users;")
    assert outcome.ok
    # The result payload must carry an "sql" entry.
    assert "sql" in outcome.data
    assert outcome.trace.stage == "safety"


def test_safety_allows_with_select_cte():
    """A read-only ``WITH ... SELECT`` common table expression is accepted."""
    query = """
    WITH recent AS (
      SELECT id FROM users WHERE created_at > '2024-01-01'
    )
    SELECT * FROM users u JOIN recent r ON u.id = r.id;
    """
    assert Safety().check(query).ok


def test_safety_allows_select_with_comments_and_newlines():
    """Block/line comments and newlines around a single SELECT do not cause rejection."""
    checker = Safety()
    statement = "/* head */ \n -- inline\n SELECT 1; -- tail"
    assert checker.check(statement).ok


def test_safety_allows_keywords_inside_string_literals():
    """Forbidden keywords that appear only inside quoted literals must not cause rejection."""
    verdict = Safety().check("SELECT 'DROP TABLE x' as note, 'delete from y' as text;")
    # Surface the checker's error in the pytest failure message.
    assert verdict.ok, verdict.error


def test_safety_blocks_delete():
    """A standalone DELETE is rejected with a 'Forbidden' or 'Non-SELECT' error."""
    verdict = Safety().check("DELETE FROM users;")
    assert not verdict.ok
    # A missing/empty error collection is treated as "no messages".
    messages = verdict.error or []
    assert any(("Forbidden" in msg) or ("Non-SELECT" in msg) for msg in messages)


@pytest.mark.parametrize(
    "sql",
    [
        "UPDATE users SET name='X' WHERE id=1;",
        "INSERT INTO users(id) VALUES (1);",
        "DROP TABLE users;",
        "CREATE TABLE x(id INT);",
        "ALTER TABLE users ADD COLUMN x INT;",
        "ATTACH DATABASE 'hack.db' AS h;",
        "PRAGMA journal_mode=WAL;",
    ],
)
def test_safety_blocks_forbidden_statements(sql):
    """Each write, DDL, ATTACH or PRAGMA statement in the list is rejected."""
    assert not Safety().check(sql).ok


def test_safety_blocks_stacked_delete_after_select():
    """A harmless SELECT must not smuggle a second, destructive statement through."""
    stacked = "SELECT * FROM users; DELETE FROM users;"
    assert not Safety().check(stacked).ok


def test_safety_blocks_stacked_delete_with_spaces():
    """Whitespace and a newline between stacked statements must not hide a DELETE.

    The second statement is written as valid ``DELETE FROM`` syntax so the test
    exercises stacked-statement detection, rather than passing merely because
    the checker rejects malformed SQL (``DELETE users`` omits ``FROM``, which
    SQLite and PostgreSQL do not accept).
    """
    s = Safety()
    sql = "SELECT * FROM users ;   \n  DELETE FROM users;"
    r = s.check(sql)
    assert not r.ok


def test_safety_blocks_delete_inside_cte():
    """A data-modifying statement hidden inside a CTE body is still rejected."""
    checker = Safety()
    query = """
    WITH bad AS (DELETE FROM users)
    SELECT * FROM users;
    """
    outcome = checker.check(query)
    assert not outcome.ok


@pytest.mark.parametrize(
    "sql",
    [
        "/*D*/ROP TABLE users;",
        "PR/*x*/AGMA journal_mode=WAL;",
        "AL/* comment */TER TABLE x ADD COLUMN y INT;",
        # Keyword split by a comment, matching the two cases above. The first
        # case hides the "D" *inside* the comment, so removing the comment
        # leaves "ROP" and never exercises a split DROP keyword.
        "DR/*x*/OP TABLE users;",
    ],
)
def test_safety_blocks_comment_obfuscation(sql):
    """Keywords split or partly hidden by ``/* */`` comments are still rejected."""
    s = Safety()
    r = s.check(sql)
    assert not r.ok


@pytest.mark.parametrize(
    "sql",
    [
        "pragma journal_mode=WAL;",  # lower-case
        "  PRAGMA  user_version = 5 ; ",
        "\nATTACH DATABASE 'hack.db' AS h;",
    ],
)
def test_safety_blocks_forbidden_case_and_spacing(sql):
    """Forbidden statements are rejected regardless of letter case or surrounding whitespace."""
    verdict = Safety().check(sql)
    assert not verdict.ok


def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
    """Only a non-empty statement after the first ``;`` counts as stacking.

    A trailing comment after the final ``;`` adds no statement, so that query
    must be accepted. A ``DROP`` placed after a ``/* spacer */`` comment is a
    real second statement and must be rejected. Despite the test name, the
    comment-only case is expected to PASS.
    """
    s = Safety()
    comment_only_tail = "SELECT 1;  -- now do something bad\n"
    hidden_second_stmt = "SELECT 1;  /* spacer */  DROP TABLE x;"
    allowed = s.check(comment_only_tail)
    # Surface the checker's error on failure.
    assert allowed.ok, allowed.error
    assert not s.check(hidden_second_stmt).ok