Spaces:
Sleeping
Sleeping
File size: 7,323 Bytes
8206fba 575c346 8206fba 5f8f2ae 8206fba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import os
import sqlite3
from pathlib import Path
from typing import Any
import logging
# Configure root logging once at import time: INFO level, with timestamp,
# level, and logger name prefixed to every message.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-wide logger for the benchmark.
log = logging.getLogger("llmsql-bench")
import re
# Compiled once at import time instead of on every call / every SELECT hit.
# _SELECT_RE locates candidate query starts; _QUERY_RE captures from SELECT up
# to (but not including) a semicolon, a block of 4+ newlines, a markdown code
# fence, or end-of-string.
_SELECT_RE = re.compile(r"SELECT\b", re.IGNORECASE)
_QUERY_RE = re.compile(r"SELECT\b.*?(?=;|\n{4,}|```|$)", re.IGNORECASE | re.DOTALL)


def find_sql(model_output: str, limit: int = 10) -> list[str]:
    """Extract SQL SELECT queries from the model's response.

    Args:
        model_output (str): Model's response as string.
        limit (int, optional): The number of SQL queries to return. Defaults to 10.

    Returns:
        list[str]: Deduplicated SQL queries, in order of first appearance.
    """
    results: list[str] = []
    for match in _SELECT_RE.finditer(model_output):
        # Search from this SELECT keyword's position in the full string
        # (equivalent to slicing, but without copying the tail each time).
        query_match = _QUERY_RE.search(model_output, match.start())
        if query_match:
            query = query_match.group(0).strip()
            if query and query not in results:
                results.append(query)
                if len(results) == limit:
                    # Enough queries collected; no need to scan further.
                    break
    return results
def execute_sql(conn: sqlite3.Connection, sql: str) -> list[tuple] | None:
"""
Execute a SQL query on the given SQLite connection and return its results.
The results are always sorted to avoid differences caused by row order (order agnostic).
If the query fails, the function logs the error and returns None.
Args:
conn (sqlite3.Connection): An active SQLite database connection.
sql (str): SQL query string to execute.
Returns:
Optional[List[Tuple]]:
- Sorted list of result rows (each row as a tuple) if successful.
- [(None,)] if the query executed but returned NULL values.
- None if the SQL execution failed due to an exception.
"""
try:
cur = conn.cursor()
cur.execute(sql)
results = cur.fetchall()
return sorted(results)
except Exception:
return None
# Matches the placeholder in any of the three spellings the model may emit.
# The \b on the bare form stops "FROM Table" from matching a prefix of a
# longer identifier (e.g. "FROM Tables"), which chained str.replace() did,
# producing mangled SQL like `FROM "real_id"s`.
_PLACEHOLDER_RE = re.compile(r"""FROM (?:'Table'|"Table"|Table\b)""")


def fix_table_name(sql: str, table_id: str) -> str:
    """
    Replace placeholder table name in the SQL query with the actual table ID.

    During evaluation, the LLM is instructed to always generate queries using
    a generic placeholder table name (`FROM Table`, `FROM "Table"`, or `FROM 'Table'`).
    This keeps the model's task simpler and avoids requiring it to memorize or
    reproduce arbitrary, dataset-specific table IDs.

    This function post-processes the model's SQL output by replacing the
    placeholder with the true table identifier for the current question.

    Args:
        sql (str): SQL query string produced by the model, using "Table" as placeholder.
        table_id (str): Actual table name/identifier for the current question.

    Returns:
        str: SQL query with the correct table name substituted, whitespace-stripped.
    """
    # A callable replacement avoids re.sub interpreting backslashes or
    # group references that might appear inside table_id.
    return _PLACEHOLDER_RE.sub(lambda _m: f'FROM "{table_id}"', sql).strip()
def evaluate_sample(
    item: dict[str, int | str],
    questions: dict[int, dict[str, str]],
    conn: sqlite3.Connection,
) -> tuple[int, dict[str, Any], dict[str, int]]:
    """
    Evaluate a single model prediction against the gold (ground-truth) SQL query.

    This function:
      - Retrieves the gold SQL query and question metadata for the given `question_id`.
      - Executes the gold SQL and the model's at most 10 predicted SQL queries on the SQLite DB.
      - Compares their results to determine whether the gold and at least one prediction match.
      - Tracks special cases such as SQL errors or queries returning NULL results.

    Args:
        item (dict): A single model prediction entry. Must contain:
            - "question_id": ID of the benchmark question (int).
            - "completion": The raw SQL string predicted by the model.
        questions (dict): Dictionary mapping `question_id` → question metadata:
            {"sql": ..., "table_id": ..., "question": ...}.
        conn (sqlite3.Connection): Active SQLite connection used to run queries.

    Returns:
        tuple:
            is_match (int): 1 if a prediction matches gold SQL results, else 0.
            mismatch_info (dict): Always returned; contains the question, gold
                SQL, model output, gold results, and prediction results (the
                matching result when correct, otherwise the last prediction's
                result). Callers should consult `is_match` to decide whether
                to treat this as a mismatch.
            metrics_update (dict): Partial metrics for this prediction:
                {"pred_none": int, "gold_none": int, "sql_error": int}.
    """
    qid = item["question_id"]
    # NOTE(review): assert-based validation vanishes under `python -O`; kept
    # as-is so callers catching AssertionError keep working.
    assert isinstance(qid, int), (
        "question_id in the outputs file needs to be of type int."
    )
    q_info = questions[qid]
    table_id, gold_sql, question_text = (
        q_info["table_id"],
        q_info["sql"],
        q_info["question"],
    )
    gold_results = execute_sql(conn, gold_sql)
    pred_none = gold_none = sql_error = 0
    if gold_results == [(None,)]:
        gold_none = 1
    # Flag for whether the prediction was correct
    is_match = 0
    last_pred_res = None  # store last prediction results for mismatch logging
    good_pred_result = None  # results of a matching prediction, if any
    assert isinstance(item["completion"], str), (
        f"Completion field in outputs file must be of type string: {item['completion']}. Type: {type(item['completion'])}"
    )
    # Loop over ALL SQL queries extracted from the model output (no early
    # break) so the error/NULL metrics reflect every candidate query.
    for pred_sql in find_sql(item["completion"]):
        # Replace placeholder table names with the actual one
        pred_sql_fixed = fix_table_name(pred_sql, table_id)
        # Execute predicted SQL
        pred_res = execute_sql(conn, pred_sql_fixed)
        # Update metrics
        if pred_res is None:  # execution failed
            sql_error += 1
        elif pred_res == [(None,)]:  # returned NULL-equivalent
            pred_none += 1
        last_pred_res = pred_res
        # If both gold and prediction executed successfully and match → success
        if (
            gold_results is not None
            and pred_res is not None
            and gold_results == pred_res
        ):
            good_pred_result = pred_res
            is_match = 1
    # Details for debugging/logging; built unconditionally (see docstring).
    mismatch_info = {
        "question_id": qid,
        "question": question_text,
        "gold_sql": gold_sql,
        "model_output": item["completion"],
        "gold_results": gold_results,
        "prediction_results": good_pred_result
        if good_pred_result is not None
        else last_pred_res,
    }
    return (
        is_match,
        mismatch_info,
        {"pred_none": pred_none, "gold_none": gold_none, "sql_error": sql_error},
    )
|