"""Tests for SQL post-processing and related functions.

Covers section 2.5 of the explanatory note: artifact stripping,
validation via sqlglot, normalization for Exact Match, and the AST-level
is_select_only guardrail.
"""
from src.models.postprocess import (
is_select_only,
is_valid_sql,
normalize_sql,
postprocess,
strip_model_artifacts,
)
# ──────────────────────────────────────────────────────────────────────
# strip_model_artifacts
# ──────────────────────────────────────────────────────────────────────
def test_strip_markdown_block_with_lang():
    """A ```sql fenced block is unwrapped so the result begins with SELECT."""
    cleaned = strip_model_artifacts("```sql\nSELECT * FROM users;\n```")
    assert cleaned.upper().startswith("SELECT")
def test_strip_markdown_block_without_lang():
    """A bare ``` fence without a language tag is unwrapped as well."""
    cleaned = strip_model_artifacts("```\nSELECT id FROM t;\n```")
    assert cleaned.upper().startswith("SELECT")
def test_strip_sql_prefix():
    """With a leading "SQL:" label, the stripped output begins with SELECT."""
    cleaned = strip_model_artifacts("SQL: SELECT 1;")
    assert cleaned.upper().startswith("SELECT")
def test_strip_russian_prefix():
    """With a Russian "Answer:" label, the stripped output begins with SELECT."""
    # "Ответ:" is Russian for "Answer:". The literal must be real Cyrillic
    # (UTF-8), otherwise the test no longer exercises a Russian prefix.
    raw = "Ответ: SELECT name FROM students;"
    assert strip_model_artifacts(raw).upper().startswith("SELECT")
def test_strip_natural_language_before_select():
    """Free-form Russian text before SELECT is removed from the output."""
    # Russian for "Here is the SQL that answers the question:".
    raw = "Вот SQL, который отвечает на вопрос: SELECT * FROM t WHERE id = 1;"
    out = strip_model_artifacts(raw)
    assert out.upper().startswith("SELECT")
    # "Вот" ("Here is") is the first word of the preamble; it must be gone.
    assert "Вот" not in out
def test_keeps_first_statement_of_two():
    """Of two semicolon-separated statements, only the first is kept."""
    result = strip_model_artifacts("SELECT 1; SELECT 2;")
    assert "SELECT 1" in result
    assert "SELECT 2" not in result
def test_with_cte_is_preserved():
    """A query opening with a WITH clause still starts with WITH after stripping."""
    query = "WITH agg AS (SELECT id FROM t) SELECT * FROM agg"
    result = strip_model_artifacts(query)
    assert result.upper().startswith("WITH")
def test_strip_returns_empty_on_garbage():
    """Text containing no SQL keyword is returned unchanged.

    NOTE(review): despite its name, this test expects the input back as-is,
    not an empty string; rejecting it is left to later validation.
    """
    # There is no SQL keyword at all, so there is nothing to trim. But the
    # model has not produced an empty answer either, so the text is returned
    # as-is; validation further down the pipeline filters it out.
    raw = "просто текст без запроса"  # Russian: "just text without a query"
    assert strip_model_artifacts(raw) == raw
# ──────────────────────────────────────────────────────────────────────
# is_valid_sql
# ──────────────────────────────────────────────────────────────────────
def test_valid_select():
    """A plain SELECT with a WHERE filter is accepted as valid SQL."""
    query = "SELECT * FROM students WHERE id = 1"
    assert is_valid_sql(query)
def test_valid_with_cte():
    """A query built on a common table expression is accepted as valid SQL."""
    query = "WITH x AS (SELECT id FROM t) SELECT * FROM x"
    assert is_valid_sql(query)
def test_invalid_garbage():
    """Misspelled keywords make the string invalid SQL."""
    query = "SELEC * FRM where"
    assert not is_valid_sql(query)
def test_invalid_empty():
    """Empty and whitespace-only strings are not valid SQL."""
    for blank in ("", " "):
        assert not is_valid_sql(blank)
# ──────────────────────────────────────────────────────────────────────
# is_select_only — guardrail
# ──────────────────────────────────────────────────────────────────────
def test_select_passes_guardrail():
    """A bare SELECT passes the is_select_only guardrail."""
    query = "SELECT id FROM t"
    assert is_select_only(query)
def test_with_cte_passes_guardrail():
    """A WITH ... SELECT query passes the is_select_only guardrail."""
    query = "WITH x AS (SELECT id FROM t) SELECT * FROM x"
    assert is_select_only(query)
def test_drop_table_blocked():
    """DROP TABLE is rejected by the guardrail."""
    statement = "DROP TABLE users"
    assert not is_select_only(statement)
def test_delete_blocked():
    """DELETE is rejected by the guardrail."""
    statement = "DELETE FROM users WHERE id = 1"
    assert not is_select_only(statement)
def test_update_blocked():
    """UPDATE is rejected by the guardrail."""
    statement = "UPDATE users SET name = 'a' WHERE id = 1"
    assert not is_select_only(statement)
def test_insert_blocked():
    """INSERT is rejected by the guardrail."""
    statement = "INSERT INTO users (id, name) VALUES (1, 'a')"
    assert not is_select_only(statement)
def test_empty_blocked():
    """Empty and whitespace-only input never passes the guardrail."""
    for blank in ("", " "):
        assert not is_select_only(blank)
def test_invalid_sql_blocked_by_guardrail():
    """Unparseable input makes the guardrail return a falsy verdict."""
    # On an invalid string, is_select_only must honestly return False
    # rather than raise an exception.
    verdict = is_select_only("not a sql at all")
    assert not verdict
# ──────────────────────────────────────────────────────────────────────
# normalize_sql
# ──────────────────────────────────────────────────────────────────────
def test_normalize_collapses_whitespace():
    """Queries differing only in whitespace and letter case normalize equally."""
    # `a` deliberately mixes runs of spaces and a line break, so the test
    # actually exercises whitespace collapsing rather than only case folding.
    a = "SELECT   *\n  FROM    Users"
    b = "select * from users"
    assert normalize_sql(a) == normalize_sql(b)
def test_normalize_idempotent():
    """Normalizing an already-normalized query yields the same string."""
    query = "SELECT id FROM t WHERE x = 1"
    expected = normalize_sql(query)
    assert normalize_sql(expected) == normalize_sql(query)
def test_normalize_fallback_on_invalid():
    """Invalid SQL goes through the fallback path instead of raising."""
    # On invalid SQL the function must not crash; the fallback should kick in.
    result = normalize_sql("not really sql")
    assert isinstance(result, str)
    # Upper case is preserved: the fallback result is entirely upper-case.
    assert result == result.upper()
# ──────────────────────────────────────────────────────────────────────
# postprocess — full pipeline
# ──────────────────────────────────────────────────────────────────────
def test_postprocess_extracts_from_markdown():
    """postprocess unwraps a fenced block and keeps only its first statement."""
    raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
    out = postprocess(raw)
    # Both checks compare case-insensitively, so they do not depend on how the
    # pipeline renders keyword/identifier case. A separate exact-case
    # alternative would be redundant: any string starting with "SELECT name"
    # already satisfies the upper-cased check.
    assert out.upper().startswith("SELECT NAME")
    assert "SELECT 2" not in out.upper()
def test_postprocess_returns_empty_on_invalid():
    """A refusal containing no valid SQL makes the pipeline return ""."""
    # The text contains no valid SQL, so the pipeline must return an empty
    # string, as described in section 2.5 of the explanatory note.
    # Russian for "I can't generate SQL for this question."
    raw = "Я не могу сгенерировать SQL для этого вопроса."
    assert postprocess(raw) == ""
def test_postprocess_returns_empty_on_truncated():
    """A query cut off mid-clause yields an empty string from the pipeline."""
    # The model stopped generating in the middle of the query, leaving
    # invalid syntax.
    assert postprocess("SELECT * FROM users WHERE") == ""
def test_postprocess_keeps_valid_with_cte():
    """A valid CTE query survives the full pipeline with WITH still leading."""
    result = postprocess("WITH agg AS (SELECT id FROM t) SELECT * FROM agg")
    assert result.upper().startswith("WITH")