Spaces: Running on Zero
| import os | |
| import sqlite3 | |
| from pathlib import Path | |
| from typing import Any | |
| import logging | |
# Configure logging once at import time: timestamped, leveled records for
# the whole benchmark run.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Shared module-level logger for the benchmark harness.
log = logging.getLogger("llmsql-bench")
| import re | |
def find_sql(model_output: str, limit: int = 10) -> list[str]:
    """Extract candidate SQL queries from a model's raw text response.

    Every occurrence of the word SELECT (including SELECTs that open
    subqueries) is treated as a potential query start. Each candidate runs
    from its SELECT to the first terminator: a semicolon, a block of 4+
    consecutive newlines, a markdown code fence, or the end of the string.

    Args:
        model_output (str): Model's response as a string.
        limit (int, optional): Maximum number of queries to return. Defaults to 10.

    Returns:
        list[str]: Up to ``limit`` unique SQL queries, in order of appearance.
    """
    # Terminators that end a candidate query; compiled once, reused per match.
    terminator = re.compile(r";|\n{4,}|```")
    results: list[str] = []
    # \b on both sides avoids treating words like "SELECTED" as a query start.
    for match in re.finditer(r"\bSELECT\b", model_output, re.IGNORECASE):
        remaining = model_output[match.start() :]
        end = terminator.search(remaining)
        query = (remaining[: end.start()] if end else remaining).strip()
        if query and query not in results:
            results.append(query)
            if len(results) == limit:
                break
    return results
| def execute_sql(conn: sqlite3.Connection, sql: str) -> list[tuple] | None: | |
| """ | |
| Execute a SQL query on the given SQLite connection and return its results. | |
| The results are always sorted to avoid differences caused by row order (order agnostic). | |
| If the query fails, the function logs the error and returns None. | |
| Args: | |
| conn (sqlite3.Connection): An active SQLite database connection. | |
| sql (str): SQL query string to execute. | |
| Returns: | |
| Optional[List[Tuple]]: | |
| - Sorted list of result rows (each row as a tuple) if successful. | |
| - [(None,)] if the query executed but returned NULL values. | |
| - None if the SQL execution failed due to an exception. | |
| """ | |
| try: | |
| cur = conn.cursor() | |
| cur.execute(sql) | |
| results = cur.fetchall() | |
| return sorted(results) | |
| except Exception: | |
| return None | |
def fix_table_name(sql: str, table_id: str) -> str:
    """
    Replace placeholder table name in the SQL query with the actual table ID.

    During evaluation, the LLM is instructed to always generate queries using
    a generic placeholder table name (`FROM Table`, `FROM "Table"`, or
    `FROM 'Table'`). This keeps the model's task simpler and avoids requiring
    it to memorize or reproduce arbitrary, dataset-specific table IDs.
    This function post-processes the model's SQL output by replacing the
    placeholder with the true table identifier for the current question.

    Args:
        sql (str): SQL query string produced by the model, using "Table" as placeholder.
        table_id (str): Actual table name/identifier for the current question.

    Returns:
        str: SQL query with the correct table name substituted, whitespace-stripped.
    """
    # \b after the bare Table keeps identifiers such as "Tables" intact,
    # which chained str.replace would have mangled; \s+ also tolerates
    # multiple spaces between FROM and the placeholder.
    placeholder = r"""FROM\s+(?:'Table'|"Table"|Table\b)"""
    # Lambda replacement: table_id is inserted literally, so backslashes or
    # \g sequences in the ID cannot be misread as re.sub escapes.
    return re.sub(placeholder, lambda _: f'FROM "{table_id}"', sql).strip()
def evaluate_sample(
    item: dict[str, int | str],
    questions: dict[int, dict[str, str]],
    conn: sqlite3.Connection,
) -> tuple[int, dict[str, Any] | None, dict[Any, Any]]:
    """
    Evaluate a single model prediction against the gold (ground-truth) SQL query.

    This function:
    - Retrieves the gold SQL query and question metadata for the given `question_id`.
    - Executes the gold SQL and the model's at most 10 predicted SQL queries on the SQLite DB.
    - Compares their results to determine whether at least one prediction matches the gold.
    - Tracks special cases such as SQL errors or queries returning NULL results.
    - Returns evaluation metrics and result details for debugging/logging.

    Args:
        item (dict): A single model prediction entry. Must contain:
            - "question_id": ID of the benchmark question (int).
            - "completion": The raw SQL string predicted by the model (str).
        questions (dict): Dictionary mapping `question_id` → question metadata:
            {"sql": ..., "table_id": ..., "question": ...}.
        conn (sqlite3.Connection): Active SQLite connection used to run queries.

    Returns:
        tuple:
            is_match (int): 1 if any prediction matches gold SQL results, else 0.
            mismatch_info (dict): Details for debugging — question, gold SQL,
                                  model output, and query results. Always
                                  populated, even when the prediction matched.
            metrics_update (dict): Partial metrics for this prediction:
                {
                    "pred_none": int,
                    "gold_none": int,
                    "sql_error": int
                }

    Raises:
        AssertionError: If "question_id" is not an int or "completion" is not a str.
    """
    qid = item["question_id"]
    assert isinstance(qid, int), (
        "question_id in the outputs file needs to be of type int."
    )
    q_info = questions[qid]
    table_id, gold_sql, question_text = (
        q_info["table_id"],
        q_info["sql"],
        q_info["question"],
    )
    gold_results = execute_sql(conn, gold_sql)

    pred_none = gold_none = sql_error = 0
    if gold_results == [(None,)]:
        gold_none = 1

    is_match = 0  # becomes 1 as soon as one prediction matches the gold results
    last_pred_res = None  # last prediction's results, kept for mismatch logging
    good_pred_result = None  # results of a matching prediction, if any

    assert isinstance(item["completion"], str), (
        f"Completion field in outputs file must be of type string: {item['completion']}. Type: {type(item['completion'])}"
    )
    # Evaluate every SQL query extracted from the model output (at most 10).
    # No early break on match: pred_none/sql_error are accumulated over ALL
    # predictions, so the full loop must run.
    for pred_sql in find_sql(item["completion"]):
        # Replace placeholder table names with the actual one.
        pred_sql_fixed = fix_table_name(pred_sql, table_id)
        pred_res = execute_sql(conn, pred_sql_fixed)
        if pred_res is None:  # execution failed
            sql_error += 1
        elif pred_res == [(None,)]:  # returned NULL-equivalent
            pred_none += 1
        last_pred_res = pred_res
        # Success when both gold and prediction executed and their rows agree.
        if (
            gold_results is not None
            and pred_res is not None
            and gold_results == pred_res
        ):
            good_pred_result = pred_res
            is_match = 1
    # Result details for debugging/logging; prefer the matching prediction's
    # results when one exists, otherwise fall back to the last prediction's.
    mismatch_info = {
        "question_id": qid,
        "question": question_text,
        "gold_sql": gold_sql,
        "model_output": item["completion"],
        "gold_results": gold_results,
        "prediction_results": good_pred_result
        if good_pred_result is not None
        else last_pred_res,
    }
    return (
        is_match,
        mismatch_info,
        {"pred_none": pred_none, "gold_none": gold_none, "sql_error": sql_error},
    )