# Shizu0n's picture
# fix: hide schema internals in model rejection
# 1104ea5
import sqlparse
from sqlparse import tokens as sql_tokens
import sql_tools
# Name of the generation kind used as the key into GENERATION_BUDGETS.
SQL_GENERATION = "sql"

# Upper-cased first tokens that a model-produced statement may start with.
MODEL_SQL_STARTERS = {"SELECT", "WITH"}

# Upper-cased DDL/DML keywords that cause a model-produced statement to be
# rejected when they appear anywhere in it.
MODEL_SQL_FORBIDDEN_KEYWORDS = {
    "ALTER", "CREATE", "DELETE",
    "DROP", "INSERT", "MERGE",
    "REPLACE", "TRUNCATE", "UPDATE",
}

# Budget per generation kind.
# NOTE(review): the unit is not visible here (presumably a token limit) -- confirm with the caller.
GENERATION_BUDGETS = {SQL_GENERATION: 96}

# Prompt skeleton; the `{schema}` and `{question}` fields are filled in with
# str.format by the prompt builder.
SQL_PROMPT_TEMPLATE = (
    "<|user|>\n"
    "Given the following SQL table, write one SQL query. "
    "Output SQL only.\n\n"
    "Table: {schema}\n\n"
    "Question: {question}<|end|>\n"
    "<|assistant|>"
)
def build_sql_prompt(schema, message, chat_history=None):
    """Render the SQL-generation prompt for a single question.

    Args:
        schema: Table schema text. When it is falsy or contains only
            whitespace, the placeholder
            ``CREATE TABLE unknown (id INTEGER)`` is used instead.
        message: The user's question. It is passed, together with the
            resolved schema, through
            ``sql_tools.normalize_sql_question_to_english``.
        chat_history: Accepted but not used by this function.

    Returns:
        ``SQL_PROMPT_TEMPLATE`` with the resolved schema and the normalized
        question substituted in.
    """
    resolved_schema = (schema or "").strip()
    if not resolved_schema:
        # No usable schema was supplied; fall back to a placeholder table.
        resolved_schema = "CREATE TABLE unknown (id INTEGER)"

    normalized_question = sql_tools.normalize_sql_question_to_english(
        message, resolved_schema
    )
    return SQL_PROMPT_TEMPLATE.format(
        question=normalized_question,
        schema=resolved_schema,
    )
def build_generation_prompt(schema, message, chat_history=None):
    """Build the generation prompt by delegating to ``build_sql_prompt``.

    All three arguments are forwarded unchanged and the resulting prompt
    is returned as-is.
    """
    prompt = build_sql_prompt(schema, message, chat_history)
    return prompt
def generation_budget(kind):
    """Return the budget registered for ``kind``.

    Kinds that are not present in ``GENERATION_BUDGETS`` get the budget of
    ``SQL_GENERATION``.
    """
    fallback = GENERATION_BUDGETS[SQL_GENERATION]
    return GENERATION_BUDGETS.get(kind, fallback)
def clean_generation(text):
    """Forward ``text`` to ``sql_tools.clean_generation`` and return its result."""
    cleaned = sql_tools.clean_generation(text)
    return cleaned
def extract_sql_candidate(text):
    """Forward ``text`` to ``sql_tools.extract_sql_candidate`` and return its result."""
    candidate = sql_tools.extract_sql_candidate(text)
    return candidate
def is_sql_like(text):
    """Forward ``text`` to ``sql_tools.is_sql_like`` and return its result."""
    verdict = sql_tools.is_sql_like(text)
    return verdict
def model_sql_validation_issue(text, schema=""):
    """Explain why model output is not acceptable SQL, or return ``""``.

    The checks run in this order and the first failure wins:

    1. The stripped text must be non-empty.
    2. ``sqlparse`` must parse it into exactly one non-blank statement.
    3. The first token (skipping whitespace and comments) must be in
       ``MODEL_SQL_STARTERS`` and the statement type must be ``SELECT``.
    4. No DDL/DML token may be one of ``MODEL_SQL_FORBIDDEN_KEYWORDS``.
    5. ``sql_tools.sql_schema_validation_issue`` must report nothing, and
       the result of ``sql_tools.validate_sql`` must contain
       ``"validator-ok"``.

    Args:
        text: Candidate SQL; falsy values are treated as empty.
        schema: Schema text handed to the ``sql_tools`` validators.

    Returns:
        A short reason string, or ``""`` when every check passes. Failures
        from step 5 are reported with one generic message; the specific
        issue returned by ``sql_tools`` is not included.
    """
    candidate = (text or "").strip()
    if not candidate:
        return "empty model output"

    try:
        parsed = sqlparse.parse(candidate)
        non_blank = [stmt for stmt in parsed if str(stmt).strip()]
    except Exception:
        return "sqlparse could not parse model output"

    if len(non_blank) != 1:
        return "model output contains multiple SQL statements"
    (only,) = non_blank

    leading = only.token_first(skip_cm=True)
    starter = "" if leading is None else leading.value.strip().upper()
    # The statement type is only consulted once the starter is acceptable.
    if not (starter in MODEL_SQL_STARTERS and only.get_type().upper() == "SELECT"):
        return "model output is not SELECT/WITH SQL"

    restricted_ttypes = (sql_tokens.Keyword.DDL, sql_tokens.Keyword.DML)
    for tok in only.flatten():
        if tok.ttype not in restricted_ttypes:
            continue
        word = tok.value.strip().upper()
        if word in MODEL_SQL_FORBIDDEN_KEYWORDS:
            return f"model output contains unsupported SQL keyword: {word}"

    generic_failure = "model output failed SQL/schema validation"
    if sql_tools.sql_schema_validation_issue(candidate, schema):
        return generic_failure
    if "validator-ok" not in sql_tools.validate_sql(candidate, schema):
        return generic_failure
    return ""
def is_model_sql_allowed(text, schema=""):
    """Return ``True`` when ``model_sql_validation_issue`` reports no issue."""
    issue = model_sql_validation_issue(text, schema)
    return not issue
def model_sql_rejection_reason(text, schema=""):
    """Return the validation issue for the SQL candidate extracted from ``text``.

    The raw text is first reduced with ``extract_sql_candidate``; the
    result of ``model_sql_validation_issue`` on that candidate is returned
    (``""`` when the candidate is acceptable).
    """
    return model_sql_validation_issue(extract_sql_candidate(text), schema)
def format_generation_result(text, schema=""):
    """Turn raw model output into a ``(sql, "", validator_result)`` triple.

    The SQL candidate is extracted from ``text`` and checked with
    ``is_model_sql_allowed``.

    Returns:
        A 3-tuple whose middle element is always ``""``. When the candidate
        is allowed, the first element is ``str(candidate)`` and the third
        is ``sql_tools.validate_sql(candidate, schema)``. Otherwise the
        first element is ``""`` and the third is
        ``sql_tools.validate_sql("")``.
    """
    candidate = extract_sql_candidate(text)
    if not is_model_sql_allowed(candidate, schema):
        return "", "", sql_tools.validate_sql("")
    return str(candidate), "", sql_tools.validate_sql(candidate, schema)