import os import sqlite3 from pathlib import Path from typing import Any import logging logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) log = logging.getLogger("llmsql-bench") import re def find_sql(model_output: str, limit: int = 10) -> list[str]: """Function to extract SQL queries from the model's response Args: model_output (str): Model's response as string limit (int, optional): The number of SQL queries to return. Defaults to 10. Returns: List[str]: SQL queries from input. """ results = [] # Find all SELECT keywords that could start a query for match in re.finditer(r"(?i)(SELECT)", model_output, re.IGNORECASE): start_pos = match.start(1) # Start of SELECT word # Look for end of query constraints from this position: # semicolumn, block of at least 4 newlines, markdown code fence or just end of string. remaining = model_output[start_pos:] query_match = re.search( r"(?s)SELECT\b.*?(?=(?:;|\n{4,}|```|$))", remaining, re.IGNORECASE, ) if query_match: query = query_match.group(0).strip() if query and query not in results: results.append(query) return results[:limit] def execute_sql(conn: sqlite3.Connection, sql: str) -> list[tuple] | None: """ Execute a SQL query on the given SQLite connection and return its results. The results are always sorted to avoid differences caused by row order (order agnostic). If the query fails, the function logs the error and returns None. Args: conn (sqlite3.Connection): An active SQLite database connection. sql (str): SQL query string to execute. Returns: Optional[List[Tuple]]: - Sorted list of result rows (each row as a tuple) if successful. - [(None,)] if the query executed but returned NULL values. - None if the SQL execution failed due to an exception. """ try: cur = conn.cursor() cur.execute(sql) results = cur.fetchall() return sorted(results) except Exception: return None def fix_table_name(sql: str, table_id: str) -> str: """ Replace placeholder table name in the SQL query with the actual table ID. During evaluation, the LLM is instructed to always generate queries using a generic placeholder table name (`FROM Table`, `FROM "Table"`, or `FROM 'Table'`). This keeps the model’s task simpler and avoids requiring it to memorize or reproduce arbitrary, dataset-specific table IDs. This function post-processes the model’s SQL output by replacing the placeholder with the true table identifier for the current question. Args: sql (str): SQL query string produced by the model, using "Table" as placeholder. table_id (str): Actual table name/identifier for the current question. Returns: str: SQL query with the correct table name substituted. """ return ( sql.replace("FROM 'Table'", f'FROM "{table_id}"') .replace('FROM "Table"', f'FROM "{table_id}"') .replace("FROM Table", f'FROM "{table_id}"') .strip() ) def evaluate_sample( item: dict[str, int | str], questions: dict[int, dict[str, str]], conn: sqlite3.Connection, ) -> tuple[int, dict[str, Any] | None, dict[Any, Any]]: """ Evaluate a single model prediction against the gold (ground-truth) SQL query. This function: - Retrieves the gold SQL query and question metadata for the given `question_id`. - Executes the gold SQL and the model's at most 10 predicted SQL queries on the SQLite DB. - Compares their results to determine whether the gold and at least one prediction are matched. - Tracks special cases such as SQL errors or queries returning NULL results. - Returns evaluation metrics and mismatch details (if any). Args: item (dict): A single model prediction entry. Must contain: - "question_id": ID of the benchmark question. - "completion": The raw SQL string predicted by the model. questions (dict): Dictionary mapping `question_id` → question metadata: {"sql": ..., "table_id": ..., "question": ...}. conn (sqlite3.Connection): Active SQLite connection used to run queries. Returns: tuple: is_match (int): 1 if prediction matches gold SQL results, else 0. mismatch_info (dict or None): Details about the mismatch if incorrect, otherwise None. Includes question, gold SQL, model output, and query results. metrics_update (dict): Partial metrics for this prediction: { "pred_none": int, "gold_none": int, "sql_error": int } """ qid = item["question_id"] assert isinstance(qid, int), ( "question_id in the outputs file needs to be of type int." ) q_info = questions[qid] table_id, gold_sql, question_text = ( q_info["table_id"], q_info["sql"], q_info["question"], ) gold_results = execute_sql(conn, gold_sql) pred_none = gold_none = sql_error = 0 if gold_results == [(None,)]: gold_none = 1 # Flag for whether the prediction was correct is_match = 0 last_pred_res = None # store last prediction results for mismatch logging good_pred_result = None # Loop over all SQL queries extracted from the model output assert isinstance(item["completion"], str), ( f"Completion filed in outputs file must be of type string: {item['completion']}. Type: {type(item['completion'])}" ) for pred_sql in find_sql(item["completion"]): # Replace placeholder table names with the actual one pred_sql_fixed = fix_table_name(pred_sql, table_id) # Execute predicted SQL pred_res = execute_sql(conn, pred_sql_fixed) # Update metrics if pred_res is None: # execution failed sql_error += 1 elif pred_res == [(None,)]: # returned NULL-equivalent pred_none += 1 last_pred_res = pred_res # If both gold and prediction executed successfully and match → success if ( gold_results is not None and pred_res is not None and gold_results == pred_res ): good_pred_result = pred_res is_match = 1 # If no match was found, prepare mismatch details for debugging/logging mismatch_info = { "question_id": qid, "question": question_text, "gold_sql": gold_sql, "model_output": item["completion"], "gold_results": gold_results, "prediction_results": good_pred_result if good_pred_result is not None else last_pred_res, } return ( is_match, mismatch_info, {"pred_none": pred_none, "gold_none": gold_none, "sql_error": sql_error}, )