# llmsql-interactive-q-a / evaluate.py
import os
import sqlite3
from pathlib import Path
from typing import Any
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
log = logging.getLogger("llmsql-bench")
import re
def find_sql(model_output: str, limit: int = 10) -> list[str]:
    """Extract candidate SQL queries from a model's free-form response.

    A candidate starts at a standalone SELECT keyword and runs until the
    first terminator: a semicolon, a block of 4+ newlines, a markdown code
    fence, or the end of the string. Duplicates are dropped while keeping
    first-seen order.

    Args:
        model_output (str): Model's response as string.
        limit (int, optional): Maximum number of SQL queries to return.
            Defaults to 10.

    Returns:
        list[str]: Up to ``limit`` distinct SQL query strings.
    """
    # \b on both sides prevents matching inside words such as "SELECTED".
    select_kw = re.compile(r"(?i)\bSELECT\b")
    # (?is): case-insensitive, and '.' spans newlines; the lazy body stops at
    # the earliest terminator asserted by the lookahead.
    query_pattern = re.compile(r"(?is)\bSELECT\b.*?(?=(?:;|\n{4,}|```|$))")

    results: list[str] = []
    for kw in select_kw.finditer(model_output):
        # Anchor extraction at this SELECT with match(); search() could skip
        # ahead to a later SELECT instead of the one we just found.
        query_match = query_pattern.match(model_output, kw.start())
        if query_match:
            query = query_match.group(0).strip()
            if query and query not in results:
                results.append(query)
    return results[:limit]
def execute_sql(conn: sqlite3.Connection, sql: str) -> list[tuple] | None:
"""
Execute a SQL query on the given SQLite connection and return its results.
The results are always sorted to avoid differences caused by row order (order agnostic).
If the query fails, the function logs the error and returns None.
Args:
conn (sqlite3.Connection): An active SQLite database connection.
sql (str): SQL query string to execute.
Returns:
Optional[List[Tuple]]:
- Sorted list of result rows (each row as a tuple) if successful.
- [(None,)] if the query executed but returned NULL values.
- None if the SQL execution failed due to an exception.
"""
try:
cur = conn.cursor()
cur.execute(sql)
results = cur.fetchall()
return sorted(results)
except Exception:
return None
def fix_table_name(sql: str, table_id: str) -> str:
    """
    Replace the placeholder table name in the SQL query with the actual table ID.

    During evaluation, the LLM is instructed to always generate queries using
    a generic placeholder table name (`FROM Table`, `FROM "Table"`, or
    `FROM 'Table'`). This keeps the model's task simpler and avoids requiring
    it to memorize or reproduce arbitrary, dataset-specific table IDs.

    This function post-processes the model's SQL output by replacing the
    placeholder with the true table identifier for the current question.

    Args:
        sql (str): SQL query string produced by the model, using "Table" as
            placeholder.
        table_id (str): Actual table name/identifier for the current question.

    Returns:
        str: SQL query with the correct table name substituted, stripped of
            surrounding whitespace.
    """
    # \b stops the unquoted form from mangling longer identifiers (a plain
    # str.replace("FROM Table", ...) would rewrite "FROM Tables" too), and
    # \s+ tolerates extra spacing between FROM and the placeholder.
    # The callable repl keeps any backslashes in table_id literal instead of
    # being interpreted as regex group references.
    return re.sub(
        r"FROM\s+(?:'Table'|\"Table\"|Table\b)",
        lambda _m: f'FROM "{table_id}"',
        sql,
    ).strip()
def evaluate_sample(
    item: dict[str, int | str],
    questions: dict[int, dict[str, str]],
    conn: sqlite3.Connection,
) -> tuple[int, dict[str, Any] | None, dict[Any, Any]]:
    """
    Evaluate a single model prediction against the gold (ground-truth) SQL query.

    This function:
    - Retrieves the gold SQL query and question metadata for the given `question_id`.
    - Executes the gold SQL and the model's at most 10 predicted SQL queries on the SQLite DB.
    - Compares their results to determine whether the gold and at least one prediction are matched.
    - Tracks special cases such as SQL errors or queries returning NULL results.

    Args:
        item (dict): A single model prediction entry. Must contain:
            - "question_id": ID of the benchmark question.
            - "completion": The raw SQL string predicted by the model.
        questions (dict): Dictionary mapping `question_id` → question metadata:
            {"sql": ..., "table_id": ..., "question": ...}.
        conn (sqlite3.Connection): Active SQLite connection used to run queries.

    Returns:
        tuple:
            is_match (int): 1 if prediction matches gold SQL results, else 0.
            eval_info (dict): Always returned (for matches and mismatches
                alike) so callers can log either outcome. Includes question,
                gold SQL, model output, gold results, and the matching
                prediction's results when correct, otherwise the last
                prediction's results.
            metrics_update (dict): Partial metrics for this prediction:
                {
                    "pred_none": int,
                    "gold_none": int,
                    "sql_error": int
                }

    Raises:
        TypeError: If "question_id" is not an int or "completion" is not a str.
    """
    qid = item["question_id"]
    # Validate with explicit raises instead of assert: asserts are stripped
    # under `python -O`, which would silently skip this input validation.
    if not isinstance(qid, int):
        raise TypeError("question_id in the outputs file needs to be of type int.")
    q_info = questions[qid]
    table_id, gold_sql, question_text = (
        q_info["table_id"],
        q_info["sql"],
        q_info["question"],
    )
    gold_results = execute_sql(conn, gold_sql)

    pred_none = gold_none = sql_error = 0
    if gold_results == [(None,)]:
        gold_none = 1

    # Flag for whether any prediction matched the gold results
    is_match = 0
    last_pred_res = None  # last prediction's results, for mismatch logging
    good_pred_result = None  # results of the matching prediction, if any

    completion = item["completion"]
    if not isinstance(completion, str):
        raise TypeError(
            f"Completion field in outputs file must be of type string: "
            f"{completion}. Type: {type(completion)}"
        )

    # Loop over all SQL queries extracted from the model output; keep going
    # after a match so the error/NULL metrics cover every candidate query.
    for pred_sql in find_sql(completion):
        # Replace placeholder table names with the actual one
        pred_sql_fixed = fix_table_name(pred_sql, table_id)
        # Execute predicted SQL
        pred_res = execute_sql(conn, pred_sql_fixed)
        # Update metrics
        if pred_res is None:  # execution failed
            sql_error += 1
        elif pred_res == [(None,)]:  # returned NULL-equivalent
            pred_none += 1
        last_pred_res = pred_res
        # If both gold and prediction executed successfully and match → success
        if (
            gold_results is not None
            and pred_res is not None
            and gold_results == pred_res
        ):
            good_pred_result = pred_res
            is_match = 1

    # Details for debugging/logging; returned for both matches and mismatches.
    eval_info = {
        "question_id": qid,
        "question": question_text,
        "gold_sql": gold_sql,
        "model_output": completion,
        "gold_results": gold_results,
        "prediction_results": good_pred_result
        if good_pred_result is not None
        else last_pred_res,
    }
    return (
        is_match,
        eval_info,
        {"pred_none": pred_none, "gold_none": gold_none, "sql_error": sql_error},
    )