File size: 3,174 Bytes
5b92375
 
47affa0
 
 
 
5b92375
 
 
 
 
 
 
 
 
 
 
 
47affa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d88f966
47affa0
 
d88f966
47affa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737eaac
5b92375
 
737eaac
5b92375
 
 
737eaac
5b92375
737eaac
5b92375
 
 
 
 
737eaac
5b92375
 
 
 
 
737eaac
 
 
1104ea5
737eaac
 
 
 
5b92375
 
737eaac
 
 
 
 
 
 
 
 
 
47affa0
737eaac
 
5b92375
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import sqlparse
from sqlparse import tokens as sql_tokens
import sql_tools


SQL_GENERATION = "sql"

# Keywords a model-generated statement is allowed to start with.
MODEL_SQL_STARTERS = set(["SELECT", "WITH"])

# Keywords that get a model-generated statement rejected when they show up
# as DDL/DML keyword tokens inside it (checked by model_sql_validation_issue).
MODEL_SQL_FORBIDDEN_KEYWORDS = set(
    "ALTER CREATE DELETE DROP INSERT MERGE REPLACE TRUNCATE UPDATE".split()
)

# Per-kind generation budget; generation_budget() falls back to the SQL entry
# for unknown kinds.
# NOTE(review): presumably a max-new-tokens limit — confirm with the caller.
GENERATION_BUDGETS = dict([(SQL_GENERATION, 96)])

# Chat-formatted single-turn prompt; fill with .format(schema=..., question=...).
SQL_PROMPT_TEMPLATE = "\n".join(
    [
        "<|user|>",
        "Given the following SQL table, write one SQL query. Output SQL only.",
        "",
        "Table: {schema}",
        "",
        "Question: {question}<|end|>",
        "<|assistant|>",
    ]
)


def build_sql_prompt(schema, message, chat_history=None):
    """Render the single-turn SQL-generation prompt.

    Args:
        schema: Table schema text. It is stripped; when missing or blank, the
            placeholder ``CREATE TABLE unknown (id INTEGER)`` is used instead.
        message: The user's question. It is passed through
            ``sql_tools.normalize_sql_question_to_english`` together with the
            resolved schema before being inserted into the prompt.
        chat_history: Accepted for interface compatibility; not used.

    Returns:
        ``SQL_PROMPT_TEMPLATE`` filled with the resolved schema and question.
    """
    resolved_schema = (schema or "").strip()
    if not resolved_schema:
        resolved_schema = "CREATE TABLE unknown (id INTEGER)"
    normalized_question = sql_tools.normalize_sql_question_to_english(
        message, resolved_schema
    )
    return SQL_PROMPT_TEMPLATE.format(
        schema=resolved_schema, question=normalized_question
    )


def build_generation_prompt(schema, message, chat_history=None):
    """Build the generation prompt; only SQL generation exists, so delegate to ``build_sql_prompt``."""
    prompt = build_sql_prompt(schema, message, chat_history)
    return prompt


def generation_budget(kind):
    """Return the budget configured for ``kind``, or the SQL budget when ``kind`` is unknown."""
    fallback = GENERATION_BUDGETS[SQL_GENERATION]
    return GENERATION_BUDGETS.get(kind, fallback)


def clean_generation(text):
    """Delegate to ``sql_tools.clean_generation`` and return its result unchanged."""
    cleaned = sql_tools.clean_generation(text)
    return cleaned


def extract_sql_candidate(text):
    """Delegate to ``sql_tools.extract_sql_candidate`` and return its result unchanged."""
    candidate = sql_tools.extract_sql_candidate(text)
    return candidate


def is_sql_like(text):
    """Delegate to ``sql_tools.is_sql_like`` and return its result unchanged."""
    verdict = sql_tools.is_sql_like(text)
    return verdict


def _is_blank_statement(statement):
    """Return True when a parsed sqlparse statement holds only whitespace and comments."""
    return all(
        token.is_whitespace or token.ttype in sql_tokens.Comment
        for token in statement.flatten()
    )


def model_sql_validation_issue(text, schema=""):
    """Return why model-generated SQL must be rejected, or "" if it is accepted.

    Checks, in order:

    1. ``text`` is non-empty after stripping.
    2. ``sqlparse`` parses it into exactly one statement. Pieces made only of
       whitespace and/or comments are ignored. sqlparse may split a trailing
       comment after ``;`` into its own statement, which should not count as
       a second statement.
    3. The first non-comment token is one of ``MODEL_SQL_STARTERS`` and
       ``Statement.get_type()`` reports ``SELECT``.
    4. No DDL/DML keyword token contains a word from
       ``MODEL_SQL_FORBIDDEN_KEYWORDS``.
    5. ``sql_tools.sql_schema_validation_issue`` reports nothing, and the
       output of ``sql_tools.validate_sql`` contains ``"validator-ok"``.

    Args:
        text: Candidate SQL; ``None`` is treated as empty.
        schema: Schema text forwarded to the ``sql_tools`` validators.

    Returns:
        ``""`` when the SQL is accepted, otherwise a short rejection reason.
    """
    text = (text or "").strip()
    if not text:
        return "empty model output"
    try:
        statements = [
            statement
            for statement in sqlparse.parse(text)
            if not _is_blank_statement(statement)
        ]
    except Exception:
        # Any parser failure is reported as a rejection rather than raised.
        return "sqlparse could not parse model output"
    if not statements:
        # Comment-only input: there is no SELECT/WITH statement to accept.
        return "model output is not SELECT/WITH SQL"
    if len(statements) != 1:
        return "model output contains multiple SQL statements"

    statement = statements[0]
    first_token = statement.token_first(skip_cm=True)
    starter = first_token.value.strip().upper() if first_token is not None else ""
    if starter not in MODEL_SQL_STARTERS or statement.get_type().upper() != "SELECT":
        return "model output is not SELECT/WITH SQL"

    # Also scan every token for forbidden DDL/DML keywords that the type check
    # above may not catch (e.g. a DELETE nested inside a WITH clause).
    for token in statement.flatten():
        if token.ttype not in (sql_tokens.Keyword.DDL, sql_tokens.Keyword.DML):
            continue
        # One token can span several words (e.g. "CREATE OR REPLACE"), so
        # test each word rather than the token's whole value.
        for word in token.value.upper().split():
            if word in MODEL_SQL_FORBIDDEN_KEYWORDS:
                return f"model output contains unsupported SQL keyword: {word}"

    if sql_tools.sql_schema_validation_issue(text, schema):
        return "model output failed SQL/schema validation"
    validator = sql_tools.validate_sql(text, schema)
    if "validator-ok" not in validator:
        return "model output failed SQL/schema validation"
    return ""


def is_model_sql_allowed(text, schema=""):
    """Return True when ``model_sql_validation_issue`` reports no rejection reason."""
    issue = model_sql_validation_issue(text, schema)
    return False if issue else True


def model_sql_rejection_reason(text, schema=""):
    """Extract the SQL candidate from raw model ``text`` and return its rejection reason ("" if accepted)."""
    return model_sql_validation_issue(extract_sql_candidate(text), schema)


def format_generation_result(text, schema=""):
    """Turn raw model output into a 3-tuple result.

    The SQL candidate is extracted from ``text`` and checked with
    ``is_model_sql_allowed``.

    Returns:
        If the candidate is accepted: ``(str(candidate), "",
        sql_tools.validate_sql(candidate, schema))``.
        Otherwise: ``("", "", sql_tools.validate_sql(""))``.
        The second element is always ``""``; its purpose is not visible here.
    """
    candidate = extract_sql_candidate(text)
    if not is_model_sql_allowed(candidate, schema):
        return "", "", sql_tools.validate_sql("")
    # NOTE(review): validate_sql also ran inside is_model_sql_allowed; the
    # second call re-derives its output for the caller.
    return str(candidate), "", sql_tools.validate_sql(candidate, schema)