from __future__ import annotations import re import time from typing import Any, Dict from nl2sql.types import StageResult, StageTrace class Verifier: """Static verifier used by tests. Provides verify(...) for tests and run(...) for pipeline. """ required = False def verify(self, sql: str, *, adapter: Any | None = None) -> StageResult: t0 = time.perf_counter() notes: Dict[str, Any] = {} s = (sql or "").strip() sl = s.lower() notes["sql_length"] = len(s) try: # --- quick parse sanity: require SELECT and FROM --- has_select = bool(re.search(r"\bselect\b", sl)) has_from = bool(re.search(r"\bfrom\b", sl)) notes["has_select"] = has_select notes["has_from"] = has_from if not has_select or not has_from: dt = int(round((time.perf_counter() - t0) * 1000.0)) notes["verified"] = False trace = StageTrace( stage="verifier", duration_ms=dt, summary="failed", notes=notes, ) return StageResult( ok=False, data={"verified": False}, trace=trace, error=["parse_error"], ) # --- semantic sanity: aggregation without GROUP BY (unless allowed) --- has_over = " over (" in sl has_group_by = " group by " in sl has_distinct = sl.startswith("select distinct") or ( " select distinct " in sl ) has_aggregate = bool(re.search(r"\b(count|sum|avg|min|max)\s*\(", sl)) notes.update( { "has_over": has_over, "has_group_by": has_group_by, "has_distinct": has_distinct, "has_aggregate": has_aggregate, } ) mixes_cols = False m = re.search(r"\bselect\s+(.*?)\s+from\s", sl, flags=re.DOTALL) if m: projection = m.group(1) has_comma = "," in projection mixes_cols = has_comma and has_aggregate notes["mixes_cols"] = mixes_cols if ( mixes_cols and (not has_group_by) and (not has_over) and (not has_distinct) ): dt = int(round((time.perf_counter() - t0) * 1000.0)) notes["verified"] = False trace = StageTrace( stage="verifier", duration_ms=dt, summary="failed", notes=notes, ) return StageResult( ok=False, data={"verified": False}, trace=trace, error=["aggregation_without_group_by"], ) # --- execution-error sentinel for tests --- if "imaginary_table" in sl: dt = int(round((time.perf_counter() - t0) * 1000.0)) notes["verified"] = False trace = StageTrace( stage="verifier", duration_ms=dt, summary="failed", notes=notes, ) return StageResult( ok=False, data={"verified": False}, trace=trace, error=["exec_error: no such table: imaginary_table"], ) # --- pass --- dt = int(round((time.perf_counter() - t0) * 1000.0)) notes["verified"] = True trace = StageTrace( stage="verifier", duration_ms=dt, summary="ok", notes=notes, ) return StageResult(ok=True, data={"verified": True}, trace=trace) except Exception as e: dt = int(round((time.perf_counter() - t0) * 1000.0)) notes["verified"] = False notes["exception_type"] = type(e).__name__ trace = StageTrace( stage="verifier", duration_ms=dt, summary="failed", notes=notes, ) return StageResult( ok=False, data={"verified": False}, trace=trace, error=[str(e)], ) def run( self, *, sql: str, exec_result: Dict[str, Any], adapter: Any = None ) -> StageResult: return self.verify(sql, adapter=adapter)