File size: 7,323 Bytes
8206fba
 
 
 
 
575c346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8206fba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f8f2ae
 
 
 
 
 
 
 
 
 
8206fba
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
import sqlite3
from pathlib import Path
from typing import Any

import logging

# Configure root logging once at import time so every module in the benchmark
# shares the same timestamped output format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)

# Module-wide logger for the benchmark harness.
log = logging.getLogger("llmsql-bench")
import re

# Compiled once at import time (the originals were rebuilt on every call).
# A query-start candidate is the SELECT keyword as a whole word; the previous
# substring scan also fired inside identifiers such as "MYSELECT".
_SELECT_START_RE = re.compile(r"\bSELECT\b", re.IGNORECASE)
# A candidate query runs non-greedily from SELECT until a terminator:
# semicolon, a block of 4+ newlines, a markdown code fence, or end of string.
_SQL_QUERY_RE = re.compile(r"(?s)SELECT\b.*?(?=(?:;|\n{4,}|```|$))", re.IGNORECASE)


def find_sql(model_output: str, limit: int = 10) -> list[str]:
    """Extract candidate SQL SELECT queries from the model's raw response.

    Each whole-word occurrence of ``SELECT`` is treated as a potential query
    start; the query text extends to the next terminator (``;``, a run of
    four or more newlines, a markdown code fence, or end of string).
    Duplicates are dropped while preserving first-seen order.

    Args:
        model_output (str): Model's response as string.
        limit (int, optional): The number of SQL queries to return. Defaults to 10.

    Returns:
        List[str]: Up to ``limit`` distinct SQL queries, in order of appearance.
    """
    results: list[str] = []

    for start in _SELECT_START_RE.finditer(model_output):
        # Re-anchor the extraction at this SELECT; overlapping candidates are
        # deduplicated below, so nested subqueries yield each variant once.
        candidate = _SQL_QUERY_RE.search(model_output[start.start():])
        if candidate is None:
            continue
        query = candidate.group(0).strip()
        if query and query not in results:
            results.append(query)

    return results[:limit]


def execute_sql(conn: sqlite3.Connection, sql: str) -> list[tuple] | None:
    """
    Execute a SQL query on the given SQLite connection and return its results.

    The results are always sorted to avoid differences caused by row order (order agnostic).
    If the query fails, the function logs the error and returns None.

    Args:
        conn (sqlite3.Connection): An active SQLite database connection.
        sql (str): SQL query string to execute.

    Returns:
        Optional[List[Tuple]]:
            - Sorted list of result rows (each row as a tuple) if successful.
            - [(None,)] if the query executed but returned NULL values.
            - None if the SQL execution failed due to an exception.
    """
    try:
        cur = conn.cursor()
        cur.execute(sql)
        results = cur.fetchall()
        return sorted(results)
    except Exception:
        return None


# Matches the generic placeholder table reference the model is instructed to
# emit: FROM Table, FROM "Table", or FROM 'Table'. The trailing \b on the
# bare form prevents mangling longer identifiers (e.g. FROM Tables), which
# the previous substring replace corrupted into FROM "<id>"s.
_PLACEHOLDER_TABLE_RE = re.compile(r"FROM\s+(?:'Table'|\"Table\"|Table\b)")


def fix_table_name(sql: str, table_id: str) -> str:
    """
    Replace placeholder table name in the SQL query with the actual table ID.

    During evaluation, the LLM is instructed to always generate queries using
    a generic placeholder table name (`FROM Table`, `FROM "Table"`, or `FROM 'Table'`).
    This keeps the model's task simpler and avoids requiring it to memorize or
    reproduce arbitrary, dataset-specific table IDs.

    This function post-processes the model's SQL output by replacing the placeholder
    with the true table identifier for the current question.

    Args:
        sql (str): SQL query string produced by the model, using "Table" as placeholder.
        table_id (str): Actual table name/identifier for the current question.

    Returns:
        str: SQL query with the correct table name substituted.
    """
    # Lambda replacement so backslashes in table_id are taken literally
    # (re.sub would otherwise interpret them as escape sequences).
    return _PLACEHOLDER_TABLE_RE.sub(lambda _: f'FROM "{table_id}"', sql).strip()


def evaluate_sample(
    item: dict[str, int | str],
    questions: dict[int, dict[str, str]],
    conn: sqlite3.Connection,
) -> tuple[int, dict[str, Any] | None, dict[Any, Any]]:
    """
    Evaluate a single model prediction against the gold (ground-truth) SQL query.

    This function:
    - Retrieves the gold SQL query and question metadata for the given `question_id`.
    - Executes the gold SQL and the model's at most 10 predicted SQL queries on the SQLite DB.
    - Compares their results to determine whether the gold and at least one prediction are matched.
    - Tracks special cases such as SQL errors or queries returning NULL results.
    - Returns evaluation metrics and mismatch details.

    Args:
        item (dict): A single model prediction entry. Must contain:
                     - "question_id": ID of the benchmark question.
                     - "completion": The raw SQL string predicted by the model.
        questions (dict): Dictionary mapping `question_id` → question metadata:
                          {"sql": ..., "table_id": ..., "question": ...}.
        conn (sqlite3.Connection): Active SQLite connection used to run queries.

    Returns:
        tuple:
            is_match (int): 1 if prediction matches gold SQL results, else 0.
            mismatch_info (dict): Details for debugging/logging — question,
                                  gold SQL, model output, and query results.
                                  Always populated; intended for use by the
                                  caller when `is_match` is 0.
            metrics_update (dict): Partial metrics for this prediction:
                                   {
                                     "pred_none": int,
                                     "gold_none": int,
                                     "sql_error": int
                                   }

    Raises:
        AssertionError: If "question_id" is not an int or "completion" is not a str.
    """
    qid = item["question_id"]
    assert isinstance(qid, int), (
        "question_id in the outputs file needs to be of type int."
    )
    q_info = questions[qid]
    table_id, gold_sql, question_text = (
        q_info["table_id"],
        q_info["sql"],
        q_info["question"],
    )

    gold_results = execute_sql(conn, gold_sql)

    pred_none = gold_none = sql_error = 0

    # execute_sql returns [(None,)] when the query ran but produced only NULL.
    if gold_results == [(None,)]:
        gold_none = 1

    # Flag for whether the prediction was correct
    is_match = 0
    last_pred_res = None  # store last prediction results for mismatch logging
    good_pred_result = None
    assert isinstance(item["completion"], str), (
        f"Completion field in outputs file must be of type string: {item['completion']}. Type: {type(item['completion'])}"
    )
    # Loop over all SQL queries extracted from the model output. All candidates
    # are executed (no early break) so the pred_none/sql_error counters cover
    # every extracted query, not just those up to the first match.
    for pred_sql in find_sql(item["completion"]):
        # Replace placeholder table names with the actual one
        pred_sql_fixed = fix_table_name(pred_sql, table_id)

        # Execute predicted SQL
        pred_res = execute_sql(conn, pred_sql_fixed)

        # Update metrics
        if pred_res is None:  # execution failed
            sql_error += 1
        elif pred_res == [(None,)]:  # returned NULL-equivalent
            pred_none += 1

        last_pred_res = pred_res

        # If both gold and prediction executed successfully and match → success
        if (
            gold_results is not None
            and pred_res is not None
            and gold_results == pred_res
        ):
            good_pred_result = pred_res
            is_match = 1

    # Mismatch details for debugging/logging; prefer the matching prediction's
    # results when one was found, otherwise fall back to the last attempt.
    mismatch_info = {
        "question_id": qid,
        "question": question_text,
        "gold_sql": gold_sql,
        "model_output": item["completion"],
        "gold_results": gold_results,
        "prediction_results": good_pred_result
        if good_pred_result is not None
        else last_pred_res,
    }

    return (
        is_match,
        mismatch_info,
        {"pred_none": pred_none, "gold_none": gold_none, "sql_error": sql_error},
    )