import sqlparse
from sqlparse import tokens as sql_tokens

import sql_tools

# Key into GENERATION_BUDGETS for SQL generation requests.
SQL_GENERATION = "sql"

# A model-produced statement must begin with one of these keywords.
MODEL_SQL_STARTERS = {"SELECT", "WITH"}

# Keywords that must not appear anywhere in a model-produced statement.
MODEL_SQL_FORBIDDEN_KEYWORDS = {
    "ALTER",
    "CREATE",
    "DELETE",
    "DROP",
    "INSERT",
    "MERGE",
    "REPLACE",
    "TRUNCATE",
    "UPDATE",
}

# Per-kind generation budget.
# NOTE(review): presumably a max-new-tokens limit -- confirm against the caller.
GENERATION_BUDGETS = {
    SQL_GENERATION: 96,
}

SQL_PROMPT_TEMPLATE = (
    "<|user|>\n"
    "Given the following SQL table, write one SQL query. Output SQL only.\n\n"
    "Table: {schema}\n\n"
    "Question: {question}<|end|>\n"
    "<|assistant|>"
)


def build_sql_prompt(schema, message, chat_history=None):
    """Render SQL_PROMPT_TEMPLATE for one table schema and one user question.

    Args:
        schema: Table definition text. When it is empty, None, or only
            whitespace, a placeholder ``CREATE TABLE unknown (id INTEGER)``
            is used instead.
        message: The user's question; it is passed through
            ``sql_tools.normalize_sql_question_to_english`` before being
            inserted into the prompt.
        chat_history: Accepted for interface compatibility; not used here.

    Returns:
        The formatted prompt string.
    """
    table_schema = (schema or "").strip() or "CREATE TABLE unknown (id INTEGER)"
    question = sql_tools.normalize_sql_question_to_english(message, table_schema)
    return SQL_PROMPT_TEMPLATE.format(
        schema=table_schema,
        question=question,
    )


def build_generation_prompt(schema, message, chat_history=None):
    """Build the generation prompt; currently an alias of build_sql_prompt."""
    return build_sql_prompt(schema, message, chat_history)


def generation_budget(kind):
    """Return the budget for *kind*, falling back to the SQL budget."""
    return GENERATION_BUDGETS.get(kind, GENERATION_BUDGETS[SQL_GENERATION])


def clean_generation(text):
    """Delegate to ``sql_tools.clean_generation``."""
    return sql_tools.clean_generation(text)


def extract_sql_candidate(text):
    """Delegate to ``sql_tools.extract_sql_candidate``."""
    return sql_tools.extract_sql_candidate(text)


def is_sql_like(text):
    """Delegate to ``sql_tools.is_sql_like``."""
    return sql_tools.is_sql_like(text)


def _forbidden_keyword(statement):
    """Return the first forbidden keyword found in *statement*, or ""."""
    for token in statement.flatten():
        # Membership in ``Keyword`` covers the whole keyword family (plain
        # Keyword as well as Keyword.DDL / Keyword.DML), so keywords that
        # sqlparse does not classify as DDL/DML (e.g. TRUNCATE) are still
        # inspected. Tokens without a type are never members.
        if token.ttype not in sql_tokens.Keyword:
            continue
        # A single keyword token may hold several words (for example
        # "CREATE OR REPLACE"), so each word is checked on its own.
        for word in token.value.upper().split():
            if word in MODEL_SQL_FORBIDDEN_KEYWORDS:
                return word
    return ""


def model_sql_validation_issue(text, schema=""):
    """Explain why *text* is not acceptable model SQL.

    The checks run in order: non-empty input, parseable by sqlparse, exactly
    one statement, statement starts with SELECT/WITH and is typed SELECT by
    sqlparse, no forbidden keyword, and finally the ``sql_tools`` schema
    check and validator.

    Args:
        text: Candidate SQL; None is treated as an empty string.
        schema: Schema text forwarded to the ``sql_tools`` validators.

    Returns:
        A short human-readable reason, or "" when every check passes.
    """
    text = (text or "").strip()
    if not text:
        return "empty model output"

    try:
        statements = [
            statement
            for statement in sqlparse.parse(text)
            if str(statement).strip()
        ]
    except Exception:
        # Deliberately broad: any parser failure is reported as a rejection
        # reason rather than propagated to the caller.
        return "sqlparse could not parse model output"

    if not statements:
        # Nothing left after discarding blank statements; this is an empty
        # output, not a multi-statement one.
        return "empty model output"
    if len(statements) != 1:
        return "model output contains multiple SQL statements"

    statement = statements[0]
    first_token = statement.token_first(skip_cm=True)
    starter = first_token.value.strip().upper() if first_token is not None else ""
    if starter not in MODEL_SQL_STARTERS or statement.get_type().upper() != "SELECT":
        return "model output is not SELECT/WITH SQL"

    keyword = _forbidden_keyword(statement)
    if keyword:
        return f"model output contains unsupported SQL keyword: {keyword}"

    schema_issue = sql_tools.sql_schema_validation_issue(text, schema)
    if schema_issue:
        return "model output failed SQL/schema validation"

    validator = sql_tools.validate_sql(text, schema)
    if "validator-ok" not in validator:
        return "model output failed SQL/schema validation"
    return ""


def is_model_sql_allowed(text, schema=""):
    """Return True when model_sql_validation_issue reports no problem."""
    return not model_sql_validation_issue(text, schema)


def model_sql_rejection_reason(text, schema=""):
    """Extract the SQL candidate from raw *text* and return its rejection reason.

    Returns "" when the extracted candidate passes validation.
    """
    cleaned = extract_sql_candidate(text)
    return model_sql_validation_issue(cleaned, schema)


def format_generation_result(text, schema=""):
    """Turn raw model output into a ``(sql, "", validator_output)`` triple.

    When the extracted candidate is allowed, the triple holds the candidate
    as a string, an empty string, and the result of
    ``sql_tools.validate_sql`` for that candidate. Otherwise the SQL slot is
    empty and the validator is run on an empty string. The middle element is
    always "".
    """
    cleaned = extract_sql_candidate(text)
    if is_model_sql_allowed(cleaned, schema):
        return str(cleaned), "", sql_tools.validate_sql(cleaned, schema)
    return "", "", sql_tools.validate_sql("")