feat: add schema-aware validation
Browse files- README.md +3 -3
- app.py +55 -10
- model_io.py +25 -10
- sql_tools.py +135 -2
- tests/test_chatbot_behavior.py +87 -0
- tests/test_chatbot_core.py +1 -1
README.md
CHANGED
|
@@ -63,7 +63,7 @@ The product layer exists because a fine-tuned model is not the same thing as a r
|
|
| 63 |
- Intent routing limits the model path to `SQL_QUERY`.
|
| 64 |
- Deterministic schema tools handle explicit create/edit requests without loading the model.
|
| 65 |
- High-confidence SQL templates handle simple ranking, aggregation, count, and comparison queries before the CPU model path.
|
| 66 |
-
- SQL output is validated with `sqlparse`
|
| 67 |
- Lazy loading keeps startup cheap; the model is downloaded and loaded only when needed.
|
| 68 |
- Load and generation timeouts protect the UI from indefinite waits.
|
| 69 |
- Static fallbacks make unsupported behavior visible instead of laundering it as AI.
|
|
@@ -80,7 +80,7 @@ The product layer exists because a fine-tuned model is not the same thing as a r
|
|
| 80 |
3. Enter the question in the chat input.
|
| 81 |
4. Click **Send**.
|
| 82 |
5. Review the result in `gr.Code(language="sql")` and the source/status message.
|
| 83 |
-
- The app shows a validation badge powered by `sqlparse`.
|
| 84 |
- Known-source errors should still identify their source path; unknown-source errors should not fake certainty.
|
| 85 |
|
| 86 |
## Usage Examples
|
|
@@ -127,7 +127,7 @@ The probe prints JSON with pass/fail checks for static fallback, deterministic C
|
|
| 127 |
set PYTHONPATH=. && pytest tests/test_chatbot_core.py tests/test_chatbot_behavior.py -q
|
| 128 |
```
|
| 129 |
|
| 130 |
-
Current unit suite: **
|
| 131 |
|
| 132 |
## Run Locally
|
| 133 |
|
|
|
|
| 63 |
- Intent routing limits the model path to `SQL_QUERY`.
|
| 64 |
- Deterministic schema tools handle explicit create/edit requests without loading the model.
|
| 65 |
- High-confidence SQL templates handle simple ranking, aggregation, count, and comparison queries before the CPU model path.
|
| 66 |
+
- SQL output is validated with `sqlparse`; model output is also checked against the active `CREATE TABLE` schema before it can be shown as accepted model SQL.
|
| 67 |
- Lazy loading keeps startup cheap; the model is downloaded and loaded only when needed.
|
| 68 |
- Load and generation timeouts protect the UI from indefinite waits.
|
| 69 |
- Static fallbacks make unsupported behavior visible instead of laundering it as AI.
|
|
|
|
| 80 |
3. Enter the question in the chat input.
|
| 81 |
4. Click **Send**.
|
| 82 |
5. Review the result in `gr.Code(language="sql")` and the source/status message.
|
| 83 |
+
- The app shows a validation badge powered by `sqlparse` plus active-schema checks for model output.
|
| 84 |
- Known-source errors should still identify their source path; unknown-source errors should not fake certainty.
|
| 85 |
|
| 86 |
## Usage Examples
|
|
|
|
| 127 |
set PYTHONPATH=. && pytest tests/test_chatbot_core.py tests/test_chatbot_behavior.py -q
|
| 128 |
```
|
| 129 |
|
| 130 |
+
Current unit suite: **136 tests**. These tests avoid loading the 3.8B model and focus on routing, deterministic tools, prompt construction, model-output rejection, active-schema validation, SQL validation, UI schema-context synchronization, and error handling.
|
| 131 |
|
| 132 |
## Run Locally
|
| 133 |
|
app.py
CHANGED
|
@@ -1270,15 +1270,24 @@ def generate_response(message, chat_history, active_schema, loaded_key, conversa
|
|
| 1270 |
source_label=SOURCE_FINE_TUNED_MODEL,
|
| 1271 |
)
|
| 1272 |
|
| 1273 |
-
sql_text, _chat_text, validator = model_core.format_generation_result(
|
|
|
|
|
|
|
|
|
|
| 1274 |
model_def = model_by_key(loaded_key)
|
| 1275 |
if not sql_text:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1276 |
return _response_tuple(
|
| 1277 |
chat_history,
|
| 1278 |
message,
|
| 1279 |
state,
|
| 1280 |
-
|
| 1281 |
-
f"{SOURCE_FINE_TUNED_MODEL}. Rejected non-SELECT/WITH model output from {model_def['model_id']} in {elapsed}s.",
|
| 1282 |
sql_text="",
|
| 1283 |
validator=validator,
|
| 1284 |
status_kind="error",
|
|
@@ -1308,12 +1317,12 @@ def normalize_sql_question_to_english(message, schema=""):
|
|
| 1308 |
return sql_core.normalize_sql_question_to_english(message, schema)
|
| 1309 |
|
| 1310 |
|
| 1311 |
-
def format_generation_result(text):
|
| 1312 |
-
return model_core.format_generation_result(text)
|
| 1313 |
|
| 1314 |
|
| 1315 |
-
def validate_sql(sql_text):
|
| 1316 |
-
return sql_core.validate_sql(sql_text)
|
| 1317 |
|
| 1318 |
|
| 1319 |
def create_table_from_message(message):
|
|
@@ -1359,7 +1368,7 @@ def sync_on_load():
|
|
| 1359 |
CSS = """
|
| 1360 |
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;500;700&display=swap');
|
| 1361 |
|
| 1362 |
-
/*
|
| 1363 |
[class*="badge"],
|
| 1364 |
[class*="validator-"],
|
| 1365 |
[class*="model-tag"],
|
|
@@ -1386,6 +1395,12 @@ CSS = """
|
|
| 1386 |
--amber-text: #854F0B;
|
| 1387 |
}
|
| 1388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1389 |
* {
|
| 1390 |
box-sizing: border-box;
|
| 1391 |
}
|
|
@@ -1397,6 +1412,36 @@ CSS = """
|
|
| 1397 |
background: var(--bg-base) !important;
|
| 1398 |
color: var(--text-primary) !important;
|
| 1399 |
font-family: Space Mono, ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1400 |
}
|
| 1401 |
|
| 1402 |
.app-shell {
|
|
@@ -1561,7 +1606,7 @@ CSS = """
|
|
| 1561 |
}
|
| 1562 |
|
| 1563 |
.model-card.selected .model-score span {
|
| 1564 |
-
color: var(--teal);
|
| 1565 |
}
|
| 1566 |
|
| 1567 |
.model-score small,
|
|
@@ -1779,7 +1824,7 @@ CSS = """
|
|
| 1779 |
}
|
| 1780 |
|
| 1781 |
.schema-context span {
|
| 1782 |
-
color: var(--teal);
|
| 1783 |
font-size: 11px;
|
| 1784 |
font-weight: 500;
|
| 1785 |
}
|
|
|
|
| 1270 |
source_label=SOURCE_FINE_TUNED_MODEL,
|
| 1271 |
)
|
| 1272 |
|
| 1273 |
+
sql_text, _chat_text, validator = model_core.format_generation_result(
|
| 1274 |
+
generated_text,
|
| 1275 |
+
state.active_schema,
|
| 1276 |
+
)
|
| 1277 |
model_def = model_by_key(loaded_key)
|
| 1278 |
if not sql_text:
|
| 1279 |
+
rejection_reason = model_core.model_sql_rejection_reason(generated_text, state.active_schema)
|
| 1280 |
+
rejection_detail = (
|
| 1281 |
+
f"The fine-tuned model output was rejected because {rejection_reason}."
|
| 1282 |
+
if rejection_reason
|
| 1283 |
+
else "The fine-tuned model output was rejected by SQL/schema guardrails."
|
| 1284 |
+
)
|
| 1285 |
return _response_tuple(
|
| 1286 |
chat_history,
|
| 1287 |
message,
|
| 1288 |
state,
|
| 1289 |
+
rejection_detail,
|
| 1290 |
+
f"{SOURCE_FINE_TUNED_MODEL}. Rejected non-SELECT/WITH model output or schema-invalid model output from {model_def['model_id']} in {elapsed}s.",
|
| 1291 |
sql_text="",
|
| 1292 |
validator=validator,
|
| 1293 |
status_kind="error",
|
|
|
|
| 1317 |
return sql_core.normalize_sql_question_to_english(message, schema)
|
| 1318 |
|
| 1319 |
|
| 1320 |
+
def format_generation_result(text, schema=""):
    """Delegate to ``model_core.format_generation_result`` with the same arguments.

    Args:
        text: Raw text produced by the model.
        schema: Active schema text forwarded unchanged; defaults to ``""``.

    Returns:
        Whatever ``model_core.format_generation_result`` returns.
    """
    formatted = model_core.format_generation_result(text, schema)
    return formatted
|
| 1322 |
|
| 1323 |
|
| 1324 |
+
def validate_sql(sql_text, schema=""):
    """Delegate to ``sql_core.validate_sql`` with the same arguments.

    Args:
        sql_text: SQL text to validate.
        schema: Active schema text forwarded unchanged; defaults to ``""``.

    Returns:
        Whatever ``sql_core.validate_sql`` returns.
    """
    badge = sql_core.validate_sql(sql_text, schema)
    return badge
|
| 1326 |
|
| 1327 |
|
| 1328 |
def create_table_from_message(message):
|
|
|
|
| 1368 |
CSS = """
|
| 1369 |
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;500;700&display=swap');
|
| 1370 |
|
| 1371 |
+
/* Keep app contrast stable regardless of Spaces light/dark host theme. */
|
| 1372 |
[class*="badge"],
|
| 1373 |
[class*="validator-"],
|
| 1374 |
[class*="model-tag"],
|
|
|
|
| 1395 |
--amber-text: #854F0B;
|
| 1396 |
}
|
| 1397 |
|
| 1398 |
+
html,
|
| 1399 |
+
body,
|
| 1400 |
+
:root {
|
| 1401 |
+
color-scheme: dark !important;
|
| 1402 |
+
}
|
| 1403 |
+
|
| 1404 |
* {
|
| 1405 |
box-sizing: border-box;
|
| 1406 |
}
|
|
|
|
| 1412 |
background: var(--bg-base) !important;
|
| 1413 |
color: var(--text-primary) !important;
|
| 1414 |
font-family: Space Mono, ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
|
| 1415 |
+
--body-text-color: var(--text-primary) !important;
|
| 1416 |
+
--body-text-color-subdued: var(--text-secondary) !important;
|
| 1417 |
+
--block-title-text-color: var(--text-secondary) !important;
|
| 1418 |
+
--block-label-text-color: var(--text-secondary) !important;
|
| 1419 |
+
--input-placeholder-color: var(--text-muted) !important;
|
| 1420 |
+
}
|
| 1421 |
+
|
| 1422 |
+
.top-panel h1,
|
| 1423 |
+
.model-card h3,
|
| 1424 |
+
.model-score span,
|
| 1425 |
+
.evidence-copy h2,
|
| 1426 |
+
.evidence-card strong,
|
| 1427 |
+
.loading-title {
|
| 1428 |
+
color: var(--text-primary) !important;
|
| 1429 |
+
}
|
| 1430 |
+
|
| 1431 |
+
.top-panel p,
|
| 1432 |
+
.step-title,
|
| 1433 |
+
.model-card code,
|
| 1434 |
+
.model-score small,
|
| 1435 |
+
.model-card-footer,
|
| 1436 |
+
.evidence-copy p,
|
| 1437 |
+
.evidence-card span,
|
| 1438 |
+
.evidence-card small,
|
| 1439 |
+
.status-pill,
|
| 1440 |
+
.schema-context,
|
| 1441 |
+
.field-label,
|
| 1442 |
+
.preset-label,
|
| 1443 |
+
.message-box {
|
| 1444 |
+
color: var(--text-secondary) !important;
|
| 1445 |
}
|
| 1446 |
|
| 1447 |
.app-shell {
|
|
|
|
| 1606 |
}
|
| 1607 |
|
| 1608 |
.model-card.selected .model-score span {
|
| 1609 |
+
color: var(--teal) !important;
|
| 1610 |
}
|
| 1611 |
|
| 1612 |
.model-score small,
|
|
|
|
| 1824 |
}
|
| 1825 |
|
| 1826 |
.schema-context span {
|
| 1827 |
+
color: var(--teal) !important;
|
| 1828 |
font-size: 11px;
|
| 1829 |
font-weight: 500;
|
| 1830 |
}
|
model_io.py
CHANGED
|
@@ -59,33 +59,48 @@ def is_sql_like(text):
|
|
| 59 |
return sql_tools.is_sql_like(text)
|
| 60 |
|
| 61 |
|
| 62 |
-
def
|
| 63 |
text = (text or "").strip()
|
| 64 |
if not text:
|
| 65 |
-
return
|
| 66 |
try:
|
| 67 |
statements = [statement for statement in sqlparse.parse(text) if str(statement).strip()]
|
| 68 |
except Exception:
|
| 69 |
-
return
|
| 70 |
if len(statements) != 1:
|
| 71 |
-
return
|
| 72 |
|
| 73 |
statement = statements[0]
|
| 74 |
first_token = statement.token_first(skip_cm=True)
|
| 75 |
starter = first_token.value.strip().upper() if first_token is not None else ""
|
| 76 |
if starter not in MODEL_SQL_STARTERS or statement.get_type().upper() != "SELECT":
|
| 77 |
-
return
|
| 78 |
|
| 79 |
for token in statement.flatten():
|
| 80 |
keyword = token.value.strip().upper()
|
| 81 |
if token.ttype in (sql_tokens.Keyword.DDL, sql_tokens.Keyword.DML):
|
| 82 |
if keyword in MODEL_SQL_FORBIDDEN_KEYWORDS:
|
| 83 |
-
return
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
cleaned = extract_sql_candidate(text)
|
| 89 |
-
if is_model_sql_allowed(cleaned):
|
| 90 |
-
return str(cleaned), "", sql_tools.validate_sql(cleaned)
|
| 91 |
return "", "", sql_tools.validate_sql("")
|
|
|
|
| 59 |
return sql_tools.is_sql_like(text)
|
| 60 |
|
| 61 |
|
| 62 |
+
def model_sql_validation_issue(text, schema=""):
    """Explain why model output is not acceptable SQL, or return ``""``.

    The checks run in order and the first failure wins: empty output,
    unparseable output, anything other than exactly one statement, a
    statement that is not SELECT/WITH, a forbidden DDL/DML keyword, a schema
    mismatch, and finally a validator badge that is not ``validator-ok``.

    Args:
        text: Candidate SQL text (``None`` is treated as empty).
        schema: Active schema text passed on to the schema checks.

    Returns:
        A human-readable reason string, or ``""`` when every check passes.
    """
    candidate = (text or "").strip()
    if not candidate:
        return "empty model output"

    try:
        parsed = sqlparse.parse(candidate)
        statements = [item for item in parsed if str(item).strip()]
    except Exception:
        return "sqlparse could not parse model output"

    if len(statements) != 1:
        return "model output contains multiple SQL statements"

    (statement,) = statements
    leading = statement.token_first(skip_cm=True)
    starter = "" if leading is None else leading.value.strip().upper()
    is_read_query = starter in MODEL_SQL_STARTERS and statement.get_type().upper() == "SELECT"
    if not is_read_query:
        return "model output is not SELECT/WITH SQL"

    # Only DDL/DML keyword tokens are compared against the forbidden list.
    guarded_types = (sql_tokens.Keyword.DDL, sql_tokens.Keyword.DML)
    for token in statement.flatten():
        if token.ttype not in guarded_types:
            continue
        keyword = token.value.strip().upper()
        if keyword in MODEL_SQL_FORBIDDEN_KEYWORDS:
            return f"model output contains unsupported SQL keyword: {keyword}"

    schema_issue = sql_tools.sql_schema_validation_issue(candidate, schema)
    if schema_issue:
        return schema_issue

    badge = sql_tools.validate_sql(candidate, schema)
    return "" if "validator-ok" in badge else "model output failed SQL/schema validation"
|
| 91 |
|
| 92 |
|
| 93 |
+
def is_model_sql_allowed(text, schema=""):
    """Return ``True`` when ``model_sql_validation_issue`` reports no problem."""
    issue = model_sql_validation_issue(text, schema)
    return not issue
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def model_sql_rejection_reason(text, schema=""):
    """Return the validation issue for the SQL candidate extracted from ``text``.

    The raw text is first reduced with ``extract_sql_candidate``; the result
    of ``model_sql_validation_issue`` on that candidate is returned as-is.
    """
    return model_sql_validation_issue(extract_sql_candidate(text), schema)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def format_generation_result(text, schema=""):
    """Turn raw model text into a ``(sql, chat, validator)`` triple.

    Args:
        text: Raw text produced by the model.
        schema: Active schema text used for the acceptance check.

    Returns:
        ``(sql_text, "", validator_html)`` when the extracted candidate is
        allowed; otherwise ``("", "", validator_html_for_empty_sql)``.
    """
    candidate = extract_sql_candidate(text)
    if not is_model_sql_allowed(candidate, schema):
        # Rejected output is reported with the validator badge for empty SQL.
        return "", "", sql_tools.validate_sql("")
    return str(candidate), "", sql_tools.validate_sql(candidate, schema)
|
sql_tools.py
CHANGED
|
@@ -198,7 +198,133 @@ def is_sql_intent(message, schema=""):
|
|
| 198 |
return bool(schema and is_sql_like(message))
|
| 199 |
|
| 200 |
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
sql_text = (sql_text or "").strip()
|
| 203 |
if not sql_text:
|
| 204 |
return '<span class="validator-badge validator-empty">No SQL yet</span>'
|
|
@@ -224,7 +350,7 @@ def validate_sql(sql_text):
|
|
| 224 |
f'<span class="validator-detail">First token: {escaped_token}</span>'
|
| 225 |
)
|
| 226 |
trailing_keyword = re.search(
|
| 227 |
-
r"\b(AND|
|
| 228 |
sql_text,
|
| 229 |
flags=re.IGNORECASE,
|
| 230 |
)
|
|
@@ -245,6 +371,13 @@ def validate_sql(sql_text):
|
|
| 245 |
'<span class="validator-badge validator-warn">Check syntax</span>'
|
| 246 |
f'<span class="validator-detail">Incomplete negated predicate: NOT {escaped_token}</span>'
|
| 247 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
return '<span class="validator-badge validator-ok">Valid SQL</span>'
|
| 249 |
|
| 250 |
|
|
|
|
| 198 |
return bool(schema and is_sql_like(message))
|
| 199 |
|
| 200 |
|
| 201 |
+
# Function names the schema check skips instead of treating them as columns.
SQL_SCHEMA_FUNCTION_NAMES = {
    "AVG", "COALESCE", "COUNT",
    "LOWER", "MAX", "MIN",
    "ROUND", "SUM", "UPPER",
}

# Words that may directly follow a table name and are never recorded as its alias.
SQL_ALIAS_STOPWORDS = {
    "FULL", "GROUP", "HAVING", "INNER",
    "JOIN", "LEFT", "LIMIT", "ON",
    "ORDER", "RIGHT", "WHERE",
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _without_sql_literals(sql_text):
|
| 229 |
+
return re.sub(
|
| 230 |
+
r"'(?:''|[^'])*'|\"(?:\"\"|[^\"])*\"|\b\d+(?:\.\d+)?\b",
|
| 231 |
+
" ",
|
| 232 |
+
sql_text or "",
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _identifier_names(sql_text):
    """Collect each Name token in ``sql_text`` together with its neighbours.

    Args:
        sql_text: SQL text handed to ``sqlparse.parse``.

    Returns:
        A list of ``(name, previous_value, next_value)`` tuples, one per token
        whose type falls under ``sqlparse.tokens.Name``. The neighbour values
        are the text of the nearest non-whitespace token on each side, or
        ``""`` when there is none. An empty list is returned when parsing
        raises.
    """
    try:
        statements = [stmt for stmt in sqlparse.parse(sql_text) if str(stmt).strip()]
    except Exception:
        # Best effort: input that cannot be parsed contributes no identifiers.
        return []

    names = []
    for statement in statements:
        # Drop whitespace once so each neighbour is a direct index lookup
        # instead of a fresh slice-and-scan for every Name token.
        significant = [token for token in statement.flatten() if not token.is_whitespace]
        last_index = len(significant) - 1
        for index, token in enumerate(significant):
            if token.ttype not in sqlparse.tokens.Name:
                continue
            previous_value = significant[index - 1].value if index > 0 else ""
            next_value = significant[index + 1].value if index < last_index else ""
            names.append((token.value, previous_value, next_value))
    return names
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def sql_schema_validation_issue(sql_text, schema):
    """Return the first mismatch between ``sql_text`` and the active schema.

    Args:
        sql_text: SQL text to inspect.
        schema: Schema text handed to ``parse_create_table_schema``.

    Returns:
        ``""`` when no table name or columns can be read from ``schema`` or
        when no mismatch is found. Otherwise a message naming the problem: no
        ``FROM``/``JOIN`` table reference at all, a referenced table that is
        neither the active table nor a CTE name, or an identifier that is not
        a known column, alias, relation name or whitelisted function.
    """
    table_name, columns = parse_create_table_schema(schema)
    if not table_name or not columns:
        # Nothing usable to validate against, so report no issue.
        return ""

    expected_table = table_name.lower()
    allowed_columns = {name.lower() for name, _column_type in columns}
    # Literals are blanked so the regex scans below only see SQL structure.
    scrubbed_sql = _without_sql_literals(sql_text)

    cte_aliases = {
        match.group(1).lower()
        for match in re.finditer(
            r"(?:WITH|,)\s+([A-Za-z_][\w]*)\s+AS\s*\(",
            scrubbed_sql,
            flags=re.IGNORECASE,
        )
    }
    # Relations a FROM/JOIN may legitimately name; built once for both uses.
    known_relations = {expected_table, *cte_aliases}

    table_refs = [
        match.group(1).lower()
        for match in re.finditer(
            r"\b(?:FROM|JOIN)\s+([A-Za-z_][\w]*)",
            scrubbed_sql,
            flags=re.IGNORECASE,
        )
    ]
    if not table_refs:
        return f"Model SQL does not reference active table: {table_name}"
    for table_ref in table_refs:
        if table_ref not in known_relations:
            return f"Unknown table for active schema: {table_ref}"

    # An alias may follow the active table or any CTE name. Without the CTE
    # names here, the alias token in `FROM some_cte c` would be reported
    # below as an unknown column.
    relation_alternatives = "|".join(re.escape(name) for name in sorted(known_relations))
    table_alias_pattern = (
        r"\b(?:FROM|JOIN)\s+"
        rf"(?:{relation_alternatives})"
        r"\s+(?:AS\s+)?([A-Za-z_][\w]*)"
    )
    table_aliases = set()
    for match in re.finditer(table_alias_pattern, scrubbed_sql, flags=re.IGNORECASE):
        alias = match.group(1)
        # A clause keyword right after the table name is not an alias.
        if alias.upper() not in SQL_ALIAS_STOPWORDS:
            table_aliases.add(alias.lower())

    output_aliases = {
        match.group(1).lower()
        for match in re.finditer(r"\bAS\s+([A-Za-z_][\w]*)", scrubbed_sql, flags=re.IGNORECASE)
    }
    allowed_non_columns = {
        *known_relations,
        *table_aliases,
        *output_aliases,
    }

    for name, _previous_value, next_value in _identifier_names(sql_text):
        normalized_name = name.lower()
        if next_value == ".":
            # Qualifier in `qualifier.column`; only the column part is checked.
            continue
        if normalized_name in allowed_non_columns:
            continue
        if name.upper() in SQL_SCHEMA_FUNCTION_NAMES:
            continue
        if normalized_name in allowed_columns:
            continue
        return f"Unknown column for active schema: {name}"
    return ""
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def validate_sql(sql_text, schema=""):
|
| 328 |
sql_text = (sql_text or "").strip()
|
| 329 |
if not sql_text:
|
| 330 |
return '<span class="validator-badge validator-empty">No SQL yet</span>'
|
|
|
|
| 350 |
f'<span class="validator-detail">First token: {escaped_token}</span>'
|
| 351 |
)
|
| 352 |
trailing_keyword = re.search(
|
| 353 |
+
r"\b(AND|BY|FROM|GROUP|HAVING|JOIN|LIMIT|NOT|ON|OR|ORDER|SELECT|WHERE)\s*;?\s*$",
|
| 354 |
sql_text,
|
| 355 |
flags=re.IGNORECASE,
|
| 356 |
)
|
|
|
|
| 371 |
'<span class="validator-badge validator-warn">Check syntax</span>'
|
| 372 |
f'<span class="validator-detail">Incomplete negated predicate: NOT {escaped_token}</span>'
|
| 373 |
)
|
| 374 |
+
schema_issue = sql_schema_validation_issue(sql_text, schema)
|
| 375 |
+
if schema_issue:
|
| 376 |
+
escaped_issue = html.escape(schema_issue)
|
| 377 |
+
return (
|
| 378 |
+
'<span class="validator-badge validator-warn">Check schema</span>'
|
| 379 |
+
f'<span class="validator-detail">{escaped_issue}</span>'
|
| 380 |
+
)
|
| 381 |
return '<span class="validator-badge validator-ok">Valid SQL</span>'
|
| 382 |
|
| 383 |
|
tests/test_chatbot_behavior.py
CHANGED
|
@@ -500,6 +500,55 @@ def test_sql_model_rejects_non_sql_output_as_chat_capability(monkeypatch):
|
|
| 500 |
assert "Rejected non-SELECT/WITH model output" in status_html(result)
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
def test_sql_intent_detected():
|
| 504 |
assert app.is_sql_intent("What is the average salary per department?", app.PRESETS["employees"])
|
| 505 |
assert app.is_sql_intent("what is the most expensive product?", app.PRESETS["products"])
|
|
@@ -595,6 +644,16 @@ def test_format_generation_result_accepts_with_query():
|
|
| 595 |
assert "validator-ok" in validator
|
| 596 |
|
| 597 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
# ---------------------------------------------------------------------------
|
| 599 |
# validate_sql — starters beyond SELECT
|
| 600 |
# ---------------------------------------------------------------------------
|
|
@@ -625,6 +684,34 @@ def test_validate_sql_bare_negated_predicate_returns_warn():
|
|
| 625 |
)
|
| 626 |
|
| 627 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
def test_validate_sql_empty_returns_empty_badge():
|
| 629 |
assert app.validate_sql("") == app.EMPTY_VALIDATOR
|
| 630 |
|
|
|
|
| 500 |
assert "Rejected non-SELECT/WITH model output" in status_html(result)
|
| 501 |
|
| 502 |
|
| 503 |
+
def test_sql_model_rejects_hallucinated_schema_column(monkeypatch):
    """Model SQL selecting ``email`` from the employees preset must be rejected."""
    # Install the fake model state through monkeypatch so it is undone at
    # teardown instead of leaking into later tests.
    monkeypatch.setattr(
        app,
        "_model",
        types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0)),
    )
    monkeypatch.setattr(app, "_tokenizer", types.SimpleNamespace(eos_token_id=0, pad_token_id=0))
    monkeypatch.setattr(app, "_current_model_id", app.FINE_TUNED_MODEL_ID)

    monkeypatch.setattr(
        app,
        "_generate_model_text",
        lambda *a, **k: ("SELECT email FROM employees WHERE name = 'Alice';", 1),
    )

    result = app.generate_response(
        "find employees named Alice",
        [],
        app.PRESETS["employees"],
        app.FINE_TUNED_MODEL_KEY,
        None,
    )

    # No SQL is surfaced, the offending column is named, and the status
    # still attributes the rejection to the fine-tuned model path.
    assert sql_output(result) == ""
    assert "email" in assistant_text(result)
    assert "schema-invalid model output" in status_html(result)
    assert app.SOURCE_FINE_TUNED_MODEL in status_html(result)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def test_sql_model_accepts_valid_schema_column_output(monkeypatch):
    """Model SQL that stays within the employees preset is shown as accepted."""
    # Install the fake model state through monkeypatch so it is undone at
    # teardown instead of leaking into later tests.
    monkeypatch.setattr(
        app,
        "_model",
        types.SimpleNamespace(generation_config=types.SimpleNamespace(eos_token_id=0)),
    )
    monkeypatch.setattr(app, "_tokenizer", types.SimpleNamespace(eos_token_id=0, pad_token_id=0))
    monkeypatch.setattr(app, "_current_model_id", app.FINE_TUNED_MODEL_ID)

    monkeypatch.setattr(
        app,
        "_generate_model_text",
        lambda *a, **k: ("SELECT * FROM employees WHERE name = 'Alice';", 1),
    )

    result = app.generate_response(
        "find employees named Alice",
        [],
        app.PRESETS["employees"],
        app.FINE_TUNED_MODEL_KEY,
        None,
    )

    assert "SELECT * FROM employees WHERE name = 'Alice';" in sql_output(result)
    # The element at index 5 of the response is checked for the OK badge.
    assert "validator-ok" in result[5]
    assert app.SOURCE_FINE_TUNED_MODEL in status_html(result)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
def test_sql_intent_detected():
|
| 553 |
assert app.is_sql_intent("What is the average salary per department?", app.PRESETS["employees"])
|
| 554 |
assert app.is_sql_intent("what is the most expensive product?", app.PRESETS["products"])
|
|
|
|
| 644 |
assert "validator-ok" in validator
|
| 645 |
|
| 646 |
|
| 647 |
+
def test_format_generation_result_rejects_unknown_column_against_schema():
    """An unknown column yields empty SQL, empty chat text and the empty badge."""
    employees_schema = app.PRESETS["employees"]
    outcome = app.format_generation_result("SELECT email FROM employees;", employees_schema)
    sql, chat, validator = outcome
    assert sql == ""
    assert chat == ""
    assert validator == app.EMPTY_VALIDATOR
|
| 655 |
+
|
| 656 |
+
|
| 657 |
# ---------------------------------------------------------------------------
|
| 658 |
# validate_sql — starters beyond SELECT
|
| 659 |
# ---------------------------------------------------------------------------
|
|
|
|
| 684 |
)
|
| 685 |
|
| 686 |
|
| 687 |
+
def test_validate_sql_rejects_unknown_schema_column():
    """Selecting ``email`` against the employees preset produces a warning badge."""
    badge = app.validate_sql("SELECT email FROM employees;", app.PRESETS["employees"])
    assert "validator-warn" in badge
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def test_validate_sql_rejects_unknown_schema_table():
    """Querying ``departments`` against the employees preset produces a warning badge."""
    badge = app.validate_sql("SELECT name FROM departments;", app.PRESETS["employees"])
    assert "validator-warn" in badge
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def test_validate_sql_rejects_date_when_not_in_schema():
    """Selecting ``date`` against the employees preset produces a warning badge."""
    badge = app.validate_sql("SELECT date FROM employees;", app.PRESETS["employees"])
    assert "validator-warn" in badge
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def test_validate_sql_accepts_schema_alias_and_output_alias():
    """A table alias (``e``) and an output alias (``total``) both pass validation."""
    query = "SELECT e.name, COUNT(*) AS total FROM employees e GROUP BY e.name ORDER BY total DESC;"
    badge = app.validate_sql(query, app.PRESETS["employees"])
    assert "validator-ok" in badge
|
| 713 |
+
|
| 714 |
+
|
| 715 |
def test_validate_sql_empty_returns_empty_badge():
|
| 716 |
assert app.validate_sql("") == app.EMPTY_VALIDATOR
|
| 717 |
|
tests/test_chatbot_core.py
CHANGED
|
@@ -100,7 +100,7 @@ def test_zoologico_transcript_with_mocked_sql_model(monkeypatch):
|
|
| 100 |
def fake_generate(prompt, generation_kind):
|
| 101 |
assert generation_kind == app.model_core.SQL_GENERATION
|
| 102 |
assert "CREATE TABLE zoologico" in prompt
|
| 103 |
-
return "SELECT * FROM zoologico WHERE
|
| 104 |
|
| 105 |
monkeypatch.setattr(app, "_generate_model_text", fake_generate)
|
| 106 |
|
|
|
|
| 100 |
def fake_generate(prompt, generation_kind):
|
| 101 |
assert generation_kind == app.model_core.SQL_GENERATION
|
| 102 |
assert "CREATE TABLE zoologico" in prompt
|
| 103 |
+
return "SELECT * FROM zoologico WHERE city = 'Sao Paulo';", 1
|
| 104 |
|
| 105 |
monkeypatch.setattr(app, "_generate_model_text", fake_generate)
|
| 106 |
|