from __future__ import annotations

import os
import re
import sqlite3
import time
from contextlib import closing
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Sequence, Set, Tuple, Union

try:
    import sqlparse
    from sqlparse.sql import Function, Identifier, IdentifierList, Statement, Token, Where
    from sqlparse.tokens import DML, Keyword, Name, Number, Punctuation, String, Whitespace
except Exception:  # pragma: no cover
    sqlparse = None  # type: ignore[assignment]
    Statement = object  # type: ignore[misc,assignment]
    Token = object  # type: ignore[misc,assignment]


# Opening delimiter -> closing delimiter for SQLite string literals and
# quoted identifiers.
_QUOTE_CLOSERS = {"'": "'", '"': '"', "`": "`", "[": "]"}


def _first_statement(sql: str) -> str:
    """Return the text of *sql* up to the first statement-terminating ``;``.

    Semicolons inside string literals, quoted identifiers and SQL comments
    are not treated as terminators. If the quoting is unbalanced (an
    unterminated literal or block comment), fall back to cutting at the very
    first ``;`` so malformed input is never extended past that point.
    """
    i, n = 0, len(sql)
    while i < n:
        ch = sql[i]
        if ch == ";":
            return sql[:i]
        if ch in _QUOTE_CLOSERS:
            closer = _QUOTE_CLOSERS[ch]
            i += 1
            terminated = False
            while i < n:
                if sql[i] == closer:
                    # A doubled quote character is an escaped quote
                    # (bracketed identifiers have no escape form).
                    if closer != "]" and i + 1 < n and sql[i + 1] == closer:
                        i += 2
                        continue
                    terminated = True
                    break
                i += 1
            if not terminated:
                return sql.split(";", 1)[0]
            i += 1
            continue
        if sql.startswith("--", i):
            newline = sql.find("\n", i)
            if newline == -1:
                # The comment runs to the end of the input.
                return sql
            i = newline + 1
            continue
        if sql.startswith("/*", i):
            end = sql.find("*/", i + 2)
            if end == -1:
                return sql.split(";", 1)[0]
            i = end + 2
            continue
        i += 1
    return sql


def _normalize_sql(sql: str) -> str:
    """Clean up a raw SQL string before parsing or execution.

    Strips surrounding whitespace, a wrapping markdown code fence and a
    leading ``sql:`` label, then keeps only the first statement.

    Returns:
        The cleaned single statement, or ``""`` when *sql* is not a string.
    """
    if not isinstance(sql, str):
        return ""
    s = sql.strip()
    if s.startswith("```"):
        # Strip markdown fences if present.
        s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
        s = re.sub(r"\n?```$", "", s).strip()
    if s.lower().startswith("sql:"):
        s = s[4:].strip()
    # Keep only the first statement to avoid accidental multi-statement
    # execution.
    if ";" in s:
        s = _first_statement(s).strip()
    return s


def _connect_readonly(db_path: str) -> sqlite3.Connection:
    """Open the SQLite database at *db_path* in read-only mode.

    Read-only prevents any accidental mutation during reward computation.
    The caller owns the returned connection and must close it.

    Raises:
        sqlite3.Error: if the database cannot be opened or configured.
    """
    # Path.as_uri() percent-encodes characters such as '#', '?', '%' and
    # spaces, which would otherwise be parsed as URI syntax.
    uri = f"{Path(os.path.abspath(db_path)).as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
    try:
        conn.execute("PRAGMA query_only = ON;")
        conn.execute("PRAGMA foreign_keys = ON;")
    except sqlite3.Error:
        conn.close()
        raise
    return conn


def _with_timeout(conn: sqlite3.Connection, timeout_s: float = 1.0) -> None:
    """Install a progress handler that aborts queries after *timeout_s* seconds.

    The clock starts when this function is called, so it must be re-armed
    before every query. The handler stays installed on *conn* afterwards.
    """
    start = time.monotonic()

    def _handler() -> int:
        # A non-zero return value makes SQLite terminate the running query.
        return 1 if (time.monotonic() - start) > timeout_s else 0

    # Call handler every N VM opcodes.
    conn.set_progress_handler(_handler, 10_000)


def _list_tables(conn: sqlite3.Connection) -> List[str]:
    """Return the names of user tables in *conn* (empty list on SQLite errors)."""
    try:
        cur = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        )
        return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
    except sqlite3.Error:
        return []


def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
    """Return True if any of *table_names* occurs in *sql* as a whole word.

    Matching is case-insensitive; empty names are ignored.
    """
    s = sql.lower()
    for t in table_names:
        tl = t.lower()
        if not tl:
            continue
        if re.search(rf"\b{re.escape(tl)}\b", s):
            return True
    return False


def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
    """Return True if SQLite can produce a query plan for *sql* without error."""
    try:
        _with_timeout(conn, timeout_s=1.0)
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True
    except sqlite3.Error:
        return False


def _execute(
    conn: sqlite3.Connection, sql: str, max_rows: int = 1000
) -> Tuple[bool, List[Tuple], Optional[str]]:
    """Run *sql* with a one-second timeout and fetch at most *max_rows* rows.

    Returns:
        ``(True, rows, None)`` on success, or ``(False, [], message)`` when
        SQLite raises an error.
    """
    try:
        _with_timeout(conn, timeout_s=1.0)
        cur = conn.execute(sql)
        rows = cur.fetchmany(max_rows)
        # Normalize to plain tuples for deterministic comparison.
        norm_rows = [tuple(r) for r in rows]
        return True, norm_rows, None
    except sqlite3.Error as e:
        return False, [], str(e)


_SQL_KEYWORDS_TO_IGNORE = {
    "select", "from", "where", "join", "inner", "left", "right", "full",
    "outer", "on", "group", "by", "order", "limit", "having", "distinct",
    "union", "intersect", "except", "as", "and", "or", "not", "in", "is",
    "null", "like", "between", "case", "when", "then", "else", "end",
    "asc", "desc",
}

_SQL_FUNCTIONS_TO_IGNORE = {
    "count", "avg", "min", "max", "sum", "lower", "upper", "substr",
    "coalesce", "round", "date", "datetime", "strftime",
}

# First word of the keywords that end the table-reference part of a query.
_TABLE_CLAUSE_TERMINATORS = {"WHERE", "GROUP", "ORDER", "HAVING", "LIMIT"}


def extract_tables(sql: str) -> Set[str]:
    """
    Best-effort table extraction from SQL using sqlparse.
    Returns lowercase table names (unqualified).

    Only the first name after each FROM / JOIN keyword is collected, so
    comma-separated table lists contribute only their first entry. Without
    sqlparse a regex over ``FROM``/``JOIN`` is used instead.
    """
    sql = _normalize_sql(sql)
    if not sql:
        return set()

    if sqlparse is None:
        # Fallback: naive regex for FROM/JOIN.
        return {
            m.group(2).lower()
            for m in re.finditer(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I)
        }

    try:
        statements = sqlparse.parse(sql)
    except Exception:
        return set()

    tables: Set[str] = set()
    for st in statements:
        if not isinstance(st, Statement):
            continue
        expecting_table = False
        # flatten() yields leaf tokens only, never grouped Identifier nodes.
        for tok in st.flatten():
            if tok.ttype in Whitespace:
                continue
            if tok.ttype is Keyword:
                # Collapse inner whitespace so "LEFT  OUTER JOIN" compares cleanly.
                kw = " ".join(tok.value.upper().split())
                if kw == "FROM" or kw.endswith("JOIN"):
                    expecting_table = True
                elif kw.split(" ", 1)[0] in _TABLE_CLAUSE_TERMINATORS:
                    expecting_table = False
                continue
            if expecting_table and tok.ttype is Name:
                tables.add(tok.value.lower())
                expecting_table = False
    return tables


def extract_columns(sql: str) -> Set[str]:
    """
    Best-effort column extraction from SQL using sqlparse.
    Returns lowercase column names (unqualified).

    Every bare name that is not a known SQL keyword or function is returned,
    so table names and aliases are included as well.
    """
    sql = _normalize_sql(sql)
    if not sql:
        return set()

    if sqlparse is None:
        # Fallback: every bare word that is not a known keyword/function.
        ignored = _SQL_KEYWORDS_TO_IGNORE | _SQL_FUNCTIONS_TO_IGNORE
        words = (m.group(1).lower() for m in re.finditer(r"\b([a-zA-Z_][\w$]*)\b", sql))
        return {w for w in words if w not in ignored}

    try:
        statements = sqlparse.parse(sql)
    except Exception:
        return set()

    cols: Set[str] = set()

    def _maybe_add_col(name: Optional[str]) -> None:
        if not name:
            return
        n = name.strip().strip('"').strip("'").lower()
        if not n or n == "*":
            return
        if n in _SQL_KEYWORDS_TO_IGNORE or n in _SQL_FUNCTIONS_TO_IGNORE:
            return
        cols.add(n)

    for st in statements:
        if not isinstance(st, Statement):
            continue
        # flatten() yields leaf tokens only, so plain Name tokens are the
        # only candidates; function names are dropped by the ignore sets.
        for tok in st.flatten():
            if getattr(tok, "ttype", None) is Name:
                _maybe_add_col(tok.value)
    return cols


def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
    """
    Return (tables, columns) sets from SQLite schema; all lowercased.
    Columns are returned as a global set (unqualified).
    """
    tables: Set[str] = set()
    columns: Set[str] = set()
    for t in _list_tables(conn):
        tl = t.lower()
        if not tl:
            continue
        tables.add(tl)
        try:
            # Double embedded quotes so unusual table names stay one identifier.
            quoted = t.replace('"', '""')
            cur = conn.execute(f'PRAGMA table_info("{quoted}")')
            for row in cur.fetchall():
                if row and isinstance(row[1], str):
                    columns.add(row[1].lower())
        except sqlite3.Error:
            continue
    return tables, columns


def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
    """Return True when both result sets hold the same rows in the same order."""
    # Deterministic comparison: compare exact row tuples in order.
    return a == b


@dataclass
class RewardDebugStats:
    """Counters behind the rates reported by ``get_debug_metrics``."""

    total: int = 0
    parsed_ok: int = 0
    table_match: int = 0
    column_match: int = 0
    executed_ok: int = 0
    exact_match: int = 0


_DEBUG = RewardDebugStats()


def reset_debug_metrics() -> None:
    """Replace the module-level counters with a fresh, zeroed instance."""
    global _DEBUG
    _DEBUG = RewardDebugStats()


def get_debug_metrics() -> dict:
    """Return the debug counters as rates over ``total`` (denominator >= 1).

    NOTE(review): nothing visible in this module increments the counters, so
    all rates are 0.0 unless a caller updates ``_DEBUG`` - confirm intent.
    """
    denom = max(_DEBUG.total, 1)
    return {
        "valid_sql_rate": _DEBUG.parsed_ok / denom,
        "table_match_rate": _DEBUG.table_match / denom,
        "column_match_rate": _DEBUG.column_match / denom,
        "execution_accuracy": _DEBUG.exact_match / denom,
    }


EXECUTION_ERROR = "EXECUTION_ERROR"


def execute_sql(
    conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000
) -> Union[List[Tuple], str]:
    """
    Execute SQL safely.
    If sqlite raises ANY exception, return EXECUTION_ERROR (NOT empty list).

    On success the first *max_rows* rows are returned as plain tuples.
    """
    try:
        _with_timeout(conn, timeout_s=1.0)
        cur = conn.execute(sql)
        rows = cur.fetchmany(max_rows)
        return [tuple(r) for r in rows]
    except Exception:
        return EXECUTION_ERROR


def _sqlparse_valid_select(sql: str) -> bool:
    """
    Parse validation using sqlparse:
      - parse() non-empty
      - contains a SELECT statement

    Returns False when sqlparse is unavailable.
    """
    if sqlparse is None:
        return False
    try:
        stmts = sqlparse.parse(sql)
        if not stmts:
            return False
        for st in stmts:
            try:
                if hasattr(st, "get_type") and st.get_type() == "SELECT":
                    return True
            except Exception:
                continue
        return False
    except Exception:
        return False


def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Score a predicted SQL query against a gold query on a SQLite database.

    Scoring:
      - ``-1.0`` if the prediction is empty, is not a parseable SELECT, or
        anything unexpected goes wrong;
      - otherwise start at ``-0.2`` and add ``0.3`` for an identical table
        set, up to ``0.3`` for column overlap with the gold query, and
        ``0.2`` if the prediction executes;
      - ``1.0`` if the prediction's rows equal the gold query's rows.

    Returns:
        A reward in ``[-1.0, 1.0]``.
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        if not sql or "SELECT" not in sql.upper():
            return -1.0
        if not _sqlparse_valid_select(sql):
            return -1.0

        reward = -0.2  # valid SQL baseline

        pred_tables = extract_tables(sql)
        gold_tables = extract_tables(gold)
        if gold_tables and pred_tables == gold_tables:
            reward += 0.3

        pred_cols = extract_columns(sql)
        gold_cols = extract_columns(gold)
        if gold_cols:
            overlap = len(pred_cols & gold_cols) / len(gold_cols)
            reward += 0.3 * overlap

        # A sqlite3 connection used as a context manager only commits or
        # rolls back; closing() is what actually releases the handle.
        with closing(_connect_readonly(db_path)) as conn:
            pred_res = execute_sql(conn, sql)
            if pred_res != EXECUTION_ERROR:
                reward += 0.2
                gold_res = execute_sql(conn, gold)
                if _safe_results_equal(pred_res, gold_res):
                    return 1.0

        return max(-1.0, min(1.0, reward))
    except Exception:
        return -1.0