| import sqlparse |
| from sqlparse import tokens as sql_tokens |
| import sql_tools |
|
|
|
|
# Generation kind identifier; also the key of the default token budget below.
SQL_GENERATION = "sql"

# The first token of accepted model SQL must be one of these keywords.
MODEL_SQL_STARTERS = {"SELECT", "WITH"}

# Write/DDL keywords rejected anywhere in model SQL. This set is only checked
# against tokens sqlparse tags as Keyword.DDL / Keyword.DML, so it should list
# every such keyword other than SELECT (UPSERT is tagged DML by sqlparse too).
MODEL_SQL_FORBIDDEN_KEYWORDS = {
    "ALTER",
    "CREATE",
    "DELETE",
    "DROP",
    "INSERT",
    "MERGE",
    "REPLACE",
    "TRUNCATE",
    "UPDATE",
    "UPSERT",
}


# Maximum generation length per generation kind; unknown kinds fall back to
# the SQL_GENERATION entry (see generation_budget).
GENERATION_BUDGETS = {
    SQL_GENERATION: 96,
}


# Chat-formatted prompt; filled via str.format with `schema` and `question`.
SQL_PROMPT_TEMPLATE = (
    "<|user|>\n"
    "Given the following SQL table, write one SQL query. Output SQL only.\n\n"
    "Table: {schema}\n\n"
    "Question: {question}<|end|>\n"
    "<|assistant|>"
)
|
|
|
|
def build_sql_prompt(schema, message, chat_history=None):
    """Render SQL_PROMPT_TEMPLATE for one table schema and one user message.

    A missing or blank ``schema`` is replaced by a placeholder table
    definition. The message is passed through
    ``sql_tools.normalize_sql_question_to_english`` (together with the
    effective schema) before being inserted. ``chat_history`` is accepted
    for interface compatibility but is not used.
    """
    trimmed_schema = (schema or "").strip()
    if trimmed_schema:
        effective_schema = trimmed_schema
    else:
        effective_schema = "CREATE TABLE unknown (id INTEGER)"
    normalized_question = sql_tools.normalize_sql_question_to_english(
        message, effective_schema
    )
    return SQL_PROMPT_TEMPLATE.format(
        question=normalized_question,
        schema=effective_schema,
    )
|
|
|
|
def build_generation_prompt(schema, message, chat_history=None):
    """Build the generation prompt; currently always the SQL prompt.

    All arguments are forwarded unchanged to ``build_sql_prompt``.
    """
    prompt = build_sql_prompt(schema, message, chat_history)
    return prompt
|
|
|
|
def generation_budget(kind):
    """Return the budget configured for ``kind`` in GENERATION_BUDGETS.

    Kinds without an entry fall back to the SQL_GENERATION budget.
    """
    fallback = GENERATION_BUDGETS[SQL_GENERATION]
    return GENERATION_BUDGETS.get(kind, fallback)
|
|
|
|
def clean_generation(text):
    """Delegate to ``sql_tools.clean_generation`` and return its result as-is."""
    cleaned_text = sql_tools.clean_generation(text)
    return cleaned_text
|
|
|
|
def extract_sql_candidate(text):
    """Delegate to ``sql_tools.extract_sql_candidate`` and return its result as-is."""
    candidate = sql_tools.extract_sql_candidate(text)
    return candidate
|
|
|
|
def is_sql_like(text):
    """Delegate to ``sql_tools.is_sql_like`` and return its result as-is."""
    looks_like_sql = sql_tools.is_sql_like(text)
    return looks_like_sql
|
|
|
|
def _has_sql_content(statement):
    """Return True when a parsed statement holds more than whitespace/comments."""
    return any(
        not token.is_whitespace and token.ttype not in sql_tokens.Comment
        for token in statement.flatten()
    )


def _first_forbidden_keyword(statement):
    """Return the first forbidden DDL/DML keyword in ``statement``, or ""."""
    forbidden_ttypes = (sql_tokens.Keyword.DDL, sql_tokens.Keyword.DML)
    for token in statement.flatten():
        # Only keyword tokens are inspected, so identifiers or string literals
        # that merely spell a forbidden word are not rejected.
        if token.ttype in forbidden_ttypes:
            keyword = token.value.strip().upper()
            if keyword in MODEL_SQL_FORBIDDEN_KEYWORDS:
                return keyword
    return ""


def model_sql_validation_issue(text, schema=""):
    """Return why model-generated SQL is rejected, or "" when it is accepted.

    Checks, in order: non-empty text; parseable by sqlparse; exactly one
    statement; the statement starts with a MODEL_SQL_STARTERS keyword and
    sqlparse classifies it as SELECT; no MODEL_SQL_FORBIDDEN_KEYWORDS appear
    as DDL/DML keywords (this catches writes hidden inside a WITH clause);
    then ``sql_tools.sql_schema_validation_issue`` reports nothing and
    ``sql_tools.validate_sql`` output contains "validator-ok".

    Args:
        text: Candidate SQL text; ``None`` is treated as empty.
        schema: Table schema forwarded to the sql_tools validators.

    Returns:
        A human-readable rejection reason, or "" if every check passes.
    """
    text = (text or "").strip()
    if not text:
        return "empty model output"
    try:
        # Fragments made only of whitespace/comments (e.g. a comment after the
        # final ';', which sqlparse splits off) are not statements.
        statements = [
            statement for statement in sqlparse.parse(text) if _has_sql_content(statement)
        ]
    except Exception:
        # Model output is untrusted free text: any parser failure is a
        # rejection rather than a crash.
        return "sqlparse could not parse model output"
    if not statements:
        # Comment-only output: nothing to run, so it is not SELECT/WITH SQL.
        return "model output is not SELECT/WITH SQL"
    if len(statements) > 1:
        return "model output contains multiple SQL statements"

    statement = statements[0]
    first_token = statement.token_first(skip_cm=True)
    starter = first_token.value.strip().upper() if first_token is not None else ""
    if starter not in MODEL_SQL_STARTERS or statement.get_type().upper() != "SELECT":
        return "model output is not SELECT/WITH SQL"

    keyword = _first_forbidden_keyword(statement)
    if keyword:
        return f"model output contains unsupported SQL keyword: {keyword}"

    if sql_tools.sql_schema_validation_issue(text, schema):
        return "model output failed SQL/schema validation"
    validator = sql_tools.validate_sql(text, schema)
    if "validator-ok" not in validator:
        return "model output failed SQL/schema validation"
    return ""
|
|
|
|
def is_model_sql_allowed(text, schema=""):
    """Return True when ``model_sql_validation_issue`` reports no problem."""
    rejection = model_sql_validation_issue(text, schema)
    return not rejection
|
|
|
|
def model_sql_rejection_reason(text, schema=""):
    """Extract the SQL candidate from raw output and return its rejection reason.

    Returns "" when the extracted candidate passes
    ``model_sql_validation_issue``.
    """
    return model_sql_validation_issue(extract_sql_candidate(text), schema)
|
|
|
|
def format_generation_result(text, schema=""):
    """Build a ``(sql, "", validator_output)`` triple from raw model output.

    The SQL candidate is extracted from ``text``. If it passes
    ``is_model_sql_allowed``, it is returned (as ``str``) together with
    ``sql_tools.validate_sql(candidate, schema)``. Otherwise the SQL slot is
    "" and the third slot is ``sql_tools.validate_sql("")``. The middle
    element is always "" here; its meaning is defined by callers
    (NOTE(review): confirm).
    """
    candidate = extract_sql_candidate(text)
    if not is_model_sql_allowed(candidate, schema):
        return "", "", sql_tools.validate_sql("")
    # NOTE(review): is_model_sql_allowed already ran validate_sql on this
    # candidate, so the validator runs twice on the success path.
    return str(candidate), "", sql_tools.validate_sql(candidate, schema)
|
|