Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
b794494
1
Parent(s):
db1d448
feat(core): refine pipeline & verifier; improve Spider benchmark accuracy
Browse files- adapters/llm/openai_provider.py +305 -77
- benchmarks/evaluate_spider_pro.py +387 -431
- benchmarks/results_pro/20251108-123204/eval.jsonl +0 -5
- benchmarks/results_pro/20251108-123204/latency_per_stage.png +0 -0
- benchmarks/results_pro/20251108-123204/metrics_overview.png +0 -0
- benchmarks/results_pro/20251108-123204/results.csv +0 -6
- benchmarks/results_pro/20251108-123204/summary.json +0 -13
- benchmarks/results_pro/20251108-124153/eval.jsonl +0 -5
- benchmarks/results_pro/20251108-124153/latency_per_stage.png +0 -0
- benchmarks/results_pro/20251108-124153/metrics_overview.png +0 -0
- benchmarks/results_pro/20251108-124153/results.csv +0 -6
- benchmarks/results_pro/20251108-124153/summary.json +0 -13
- benchmarks/results_pro/20251108-125829/eval.jsonl +0 -5
- benchmarks/results_pro/20251108-125829/latency_per_stage.png +0 -0
- benchmarks/results_pro/20251108-125829/metrics_overview.png +0 -0
- benchmarks/results_pro/20251108-125829/results.csv +0 -6
- benchmarks/results_pro/20251108-125829/summary.json +0 -13
- benchmarks/results_pro/20251109-092540/eval.jsonl +5 -0
- benchmarks/results_pro/20251109-092540/summary.json +12 -0
- benchmarks/results_pro/20251109-092823/eval.jsonl +5 -0
- benchmarks/results_pro/20251109-092823/summary.json +12 -0
- benchmarks/results_pro/20251109-093743/eval.jsonl +5 -0
- benchmarks/results_pro/20251109-093743/summary.json +12 -0
- nl2sql/pipeline.py +137 -118
- nl2sql/verifier.py +266 -174
adapters/llm/openai_provider.py
CHANGED
|
@@ -1,24 +1,16 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
-
|
| 3 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from adapters.llm.base import LLMProvider
|
| 5 |
from openai import OpenAI
|
| 6 |
|
| 7 |
-
# NOTE:
|
| 8 |
-
# - Prefer proxy if PROXY_API_KEY and PROXY_BASE_URL are set.
|
| 9 |
-
# - Otherwise, fallback to OPENAI_API_KEY (+ OPENAI_BASE_URL defaulting to https://api.openai.com/v1).
|
| 10 |
-
# - Do NOT pass base_url/api_key in the constructor; rely on env vars.
|
| 11 |
-
|
| 12 |
|
| 13 |
def _resolve_api_config() -> tuple[str, str, str]:
|
| 14 |
-
"""
|
| 15 |
-
Returns (api_key, base_url, model_id) according to env.
|
| 16 |
-
Resolution order:
|
| 17 |
-
1) Proxy: PROXY_API_KEY + PROXY_BASE_URL [+ PROXY_MODEL_ID]
|
| 18 |
-
2) Direct: OPENAI_API_KEY [+ OPENAI_BASE_URL] [+ OPENAI_MODEL_ID]
|
| 19 |
-
Additionally, LLM_MODEL_ID (if set) overrides model choice.
|
| 20 |
-
"""
|
| 21 |
-
# Optional global override for model id
|
| 22 |
override_model = os.getenv("LLM_MODEL_ID")
|
| 23 |
|
| 24 |
proxy_key = os.getenv("PROXY_API_KEY")
|
|
@@ -43,74 +35,146 @@ def _resolve_api_config() -> tuple[str, str, str]:
|
|
| 43 |
|
| 44 |
|
| 45 |
class OpenAIProvider(LLMProvider):
|
|
|
|
|
|
|
| 46 |
provider_id = "openai"
|
| 47 |
|
| 48 |
def __init__(self) -> None:
|
| 49 |
-
|
| 50 |
api_key, base_url, model = _resolve_api_config()
|
| 51 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 52 |
os.environ["OPENAI_BASE_URL"] = base_url
|
| 53 |
-
# Create client using env only
|
| 54 |
self.client = OpenAI()
|
| 55 |
self.model = model
|
| 56 |
|
| 57 |
-
def plan(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
completion = self.client.chat.completions.create(
|
| 59 |
model=self.model,
|
| 60 |
messages=[
|
| 61 |
-
{"role": "system", "content":
|
| 62 |
-
{
|
| 63 |
-
"role": "user",
|
| 64 |
-
"content": f"Query: {user_query}\nSchema:\n{schema_preview}",
|
| 65 |
-
},
|
| 66 |
],
|
| 67 |
-
temperature=0,
|
| 68 |
)
|
| 69 |
-
|
|
|
|
| 70 |
usage = completion.usage
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
usage.prompt_tokens
|
| 74 |
-
usage.completion_tokens
|
| 75 |
-
self._estimate_cost(usage)
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
|
| 78 |
def generate_sql(
|
| 79 |
-
self,
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
{schema_preview}
|
| 97 |
-
Plan: {plan_text}
|
| 98 |
-
Clarifications: {clarify_answers}
|
| 99 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
completion = self.client.chat.completions.create(
|
| 101 |
model=self.model,
|
| 102 |
messages=[
|
| 103 |
-
{"role": "system", "content":
|
| 104 |
-
{"role": "user", "content":
|
| 105 |
],
|
| 106 |
-
temperature=0,
|
|
|
|
| 107 |
)
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
usage = completion.usage
|
| 110 |
-
t_in = usage.prompt_tokens if usage else None
|
| 111 |
-
t_out = usage.completion_tokens if usage else None
|
| 112 |
-
cost = self._estimate_cost(usage) if usage else None
|
| 113 |
|
|
|
|
| 114 |
try:
|
| 115 |
parsed = json.loads(content)
|
| 116 |
except json.JSONDecodeError:
|
|
@@ -126,35 +190,199 @@ class OpenAIProvider(LLMProvider):
|
|
| 126 |
|
| 127 |
sql = (parsed.get("sql") or "").strip()
|
| 128 |
rationale = parsed.get("rationale") or ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
if not sql:
|
| 130 |
raise ValueError("LLM returned empty 'sql'")
|
| 131 |
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
def repair(self, *, sql, error_msg, schema_preview):
|
| 135 |
completion = self.client.chat.completions.create(
|
| 136 |
model=self.model,
|
| 137 |
messages=[
|
| 138 |
-
{
|
| 139 |
-
|
| 140 |
-
"content": "You fix SQL queries keeping them SELECT-only.",
|
| 141 |
-
},
|
| 142 |
-
{
|
| 143 |
-
"role": "user",
|
| 144 |
-
"content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
|
| 145 |
-
},
|
| 146 |
],
|
| 147 |
-
temperature=0,
|
| 148 |
)
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
usage = completion.usage
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
usage.prompt_tokens
|
| 154 |
-
usage.completion_tokens
|
| 155 |
-
self._estimate_cost(usage)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
)
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import json
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from typing import Any, List, Tuple
|
| 7 |
+
|
| 8 |
from adapters.llm.base import LLMProvider
|
| 9 |
from openai import OpenAI
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def _resolve_api_config() -> tuple[str, str, str]:
|
| 13 |
+
"""Returns (api_key, base_url, model_id) according to env."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
override_model = os.getenv("LLM_MODEL_ID")
|
| 15 |
|
| 16 |
proxy_key = os.getenv("PROXY_API_KEY")
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
class OpenAIProvider(LLMProvider):
|
| 38 |
+
"""OpenAI LLM provider implementation."""
|
| 39 |
+
|
| 40 |
provider_id = "openai"
|
| 41 |
|
| 42 |
def __init__(self) -> None:
|
| 43 |
+
"""Initialize OpenAI client with config from environment."""
|
| 44 |
api_key, base_url, model = _resolve_api_config()
|
| 45 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 46 |
os.environ["OPENAI_BASE_URL"] = base_url
|
|
|
|
| 47 |
self.client = OpenAI()
|
| 48 |
self.model = model
|
| 49 |
|
| 50 |
+
def plan(
|
| 51 |
+
self, *, user_query: str, schema_preview: str
|
| 52 |
+
) -> Tuple[str, int, int, float]:
|
| 53 |
+
"""Generate a query plan for the SQL generation.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
user_query: The user's natural language question
|
| 57 |
+
schema_preview: Database schema information
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Tuple of (plan_text, prompt_tokens, completion_tokens, cost)
|
| 61 |
+
"""
|
| 62 |
+
system_prompt = """You are a SQL query planning expert. Analyze the user's question and database schema to create a clear execution plan.
|
| 63 |
+
|
| 64 |
+
Your plan should:
|
| 65 |
+
1. Identify the tables and columns needed
|
| 66 |
+
2. Determine any JOINs required
|
| 67 |
+
3. Specify filtering conditions (WHERE)
|
| 68 |
+
4. Identify aggregations (GROUP BY, COUNT, etc.)
|
| 69 |
+
5. Note sorting requirements (ORDER BY)
|
| 70 |
+
6. Check for special cases (DISTINCT, LIMIT, etc.)
|
| 71 |
+
|
| 72 |
+
Be concise but thorough."""
|
| 73 |
+
|
| 74 |
+
user_prompt = f"""Question: {user_query}
|
| 75 |
+
|
| 76 |
+
Database Schema:
|
| 77 |
+
{schema_preview}
|
| 78 |
+
|
| 79 |
+
Create a step-by-step plan to answer this question with SQL."""
|
| 80 |
+
|
| 81 |
completion = self.client.chat.completions.create(
|
| 82 |
model=self.model,
|
| 83 |
messages=[
|
| 84 |
+
{"role": "system", "content": system_prompt},
|
| 85 |
+
{"role": "user", "content": user_prompt},
|
|
|
|
|
|
|
|
|
|
| 86 |
],
|
| 87 |
+
temperature=0.1,
|
| 88 |
)
|
| 89 |
+
|
| 90 |
+
msg = completion.choices[0].message.content or ""
|
| 91 |
usage = completion.usage
|
| 92 |
+
|
| 93 |
+
if usage:
|
| 94 |
+
prompt_tokens = usage.prompt_tokens
|
| 95 |
+
completion_tokens = usage.completion_tokens
|
| 96 |
+
cost = self._estimate_cost(usage)
|
| 97 |
+
return (msg, prompt_tokens, completion_tokens, cost)
|
| 98 |
+
else:
|
| 99 |
+
return (msg, 0, 0, 0.0)
|
| 100 |
|
| 101 |
def generate_sql(
|
| 102 |
+
self,
|
| 103 |
+
*,
|
| 104 |
+
user_query: str,
|
| 105 |
+
schema_preview: str,
|
| 106 |
+
plan_text: str,
|
| 107 |
+
clarify_answers: dict[str, Any] | None = None,
|
| 108 |
+
) -> Tuple[str, str, int, int, float]:
|
| 109 |
+
"""Generate SQL with improved prompt for Spider benchmark.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
user_query: The user's natural language question
|
| 113 |
+
schema_preview: Database schema information
|
| 114 |
+
plan_text: Query execution plan
|
| 115 |
+
clarify_answers: Optional additional context
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
|
|
|
|
|
|
|
|
|
|
| 119 |
"""
|
| 120 |
+
system_prompt = """You are an expert SQL query generator for SQLite databases.
|
| 121 |
+
You must follow these STRICT rules to generate clean, simple SQL:
|
| 122 |
+
|
| 123 |
+
CRITICAL RULES:
|
| 124 |
+
1. Write the SIMPLEST possible SQL that answers the question
|
| 125 |
+
2. NEVER use table prefixes unless absolutely necessary for disambiguation
|
| 126 |
+
3. NEVER add aliases (AS) unless specifically requested
|
| 127 |
+
4. NEVER add LIMIT unless the question asks for a specific number of results
|
| 128 |
+
5. NEVER use DISTINCT with COUNT(*) unless explicitly needed
|
| 129 |
+
6. Use lowercase for SQL keywords (select, from, where, etc.)
|
| 130 |
+
7. Do not add unnecessary parentheses or formatting
|
| 131 |
+
8. Match exact column and table names from the schema (case-sensitive)
|
| 132 |
+
|
| 133 |
+
IMPORTANT:
|
| 134 |
+
- For counting all rows: Use COUNT(*) not COUNT(column_name)
|
| 135 |
+
- For ordering: Only add ORDER BY if the question asks for sorted results
|
| 136 |
+
- Keep the SQL as close as possible to the minimal required syntax
|
| 137 |
+
|
| 138 |
+
You must return ONLY valid JSON with exactly two keys: "sql" and "rationale".
|
| 139 |
+
The SQL should be a single line without unnecessary spaces."""
|
| 140 |
+
|
| 141 |
+
user_prompt = f"""Based on this information, generate a simple SQL query:
|
| 142 |
+
|
| 143 |
+
Question: {user_query}
|
| 144 |
+
|
| 145 |
+
Database Schema:
|
| 146 |
+
{schema_preview}
|
| 147 |
+
|
| 148 |
+
Query Plan:
|
| 149 |
+
{plan_text}
|
| 150 |
+
|
| 151 |
+
Remember: Generate the SIMPLEST possible SQL. Avoid table prefixes, aliases, and unnecessary clauses.
|
| 152 |
+
|
| 153 |
+
Example of what we want:
|
| 154 |
+
Question: "How many singers are there?"
|
| 155 |
+
Correct: {{"sql": "select count(*) from singer", "rationale": "Count all rows in singer table"}}
|
| 156 |
+
Wrong: {{"sql": "SELECT COUNT(singer.singer_id) AS total_singers FROM singer", "rationale": "..."}}
|
| 157 |
+
|
| 158 |
+
Now generate the SQL for the given question:"""
|
| 159 |
+
|
| 160 |
+
if clarify_answers:
|
| 161 |
+
user_prompt += f"\n\nAdditional context: {clarify_answers}"
|
| 162 |
+
|
| 163 |
completion = self.client.chat.completions.create(
|
| 164 |
model=self.model,
|
| 165 |
messages=[
|
| 166 |
+
{"role": "system", "content": system_prompt},
|
| 167 |
+
{"role": "user", "content": user_prompt},
|
| 168 |
],
|
| 169 |
+
temperature=0.1,
|
| 170 |
+
max_tokens=500,
|
| 171 |
)
|
| 172 |
+
|
| 173 |
+
text = completion.choices[0].message.content
|
| 174 |
+
content = text.strip() if text else ""
|
| 175 |
usage = completion.usage
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
# Parse JSON response
|
| 178 |
try:
|
| 179 |
parsed = json.loads(content)
|
| 180 |
except json.JSONDecodeError:
|
|
|
|
| 190 |
|
| 191 |
sql = (parsed.get("sql") or "").strip()
|
| 192 |
rationale = parsed.get("rationale") or ""
|
| 193 |
+
|
| 194 |
+
# Post-process SQL to ensure simplicity
|
| 195 |
+
sql = self._simplify_sql(sql)
|
| 196 |
+
|
| 197 |
if not sql:
|
| 198 |
raise ValueError("LLM returned empty 'sql'")
|
| 199 |
|
| 200 |
+
if usage:
|
| 201 |
+
prompt_tokens = usage.prompt_tokens
|
| 202 |
+
completion_tokens = usage.completion_tokens
|
| 203 |
+
cost = self._estimate_cost(usage)
|
| 204 |
+
return (sql, rationale, prompt_tokens, completion_tokens, cost)
|
| 205 |
+
else:
|
| 206 |
+
return (sql, rationale, 0, 0, 0.0)
|
| 207 |
+
|
| 208 |
+
def _simplify_sql(self, sql: str) -> str:
|
| 209 |
+
"""Post-process SQL to remove common unnecessary additions."""
|
| 210 |
+
if not sql:
|
| 211 |
+
return sql
|
| 212 |
+
|
| 213 |
+
# Remove trailing semicolon
|
| 214 |
+
sql = sql.rstrip(";")
|
| 215 |
+
|
| 216 |
+
# Remove unnecessary table prefixes in simple queries
|
| 217 |
+
# e.g., "singer.name" -> "name" when there's only one table
|
| 218 |
+
if sql.lower().count(" from ") == 1 and " join " not in sql.lower():
|
| 219 |
+
match = re.search(r"\bfrom\s+(\w+)", sql, re.IGNORECASE)
|
| 220 |
+
if match:
|
| 221 |
+
table = match.group(1)
|
| 222 |
+
sql = re.sub(rf"\b{table}\.(\w+)\b", r"\1", sql)
|
| 223 |
+
|
| 224 |
+
# Remove unnecessary DISTINCT in COUNT(*)
|
| 225 |
+
sql = re.sub(
|
| 226 |
+
r"count\s*\(\s*distinct\s+\*\s*\)",
|
| 227 |
+
"count(*)",
|
| 228 |
+
sql,
|
| 229 |
+
flags=re.IGNORECASE,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Remove big default LIMITs that weren't requested
|
| 233 |
+
sql = re.sub(
|
| 234 |
+
r"\s+limit\s+(100|1000|10000)\b",
|
| 235 |
+
"",
|
| 236 |
+
sql,
|
| 237 |
+
flags=re.IGNORECASE,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
return sql
|
| 241 |
+
|
| 242 |
+
def repair(
|
| 243 |
+
self,
|
| 244 |
+
*,
|
| 245 |
+
sql: str,
|
| 246 |
+
error_msg: str,
|
| 247 |
+
schema_preview: str,
|
| 248 |
+
) -> Tuple[str, int, int, float]:
|
| 249 |
+
"""Repair SQL with focus on simplicity.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
sql: Broken SQL query
|
| 253 |
+
error_msg: Error message from execution
|
| 254 |
+
schema_preview: Database schema information
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
Tuple of (fixed_sql, prompt_tokens, completion_tokens, cost)
|
| 258 |
+
"""
|
| 259 |
+
system_prompt = """You are a SQL repair expert. Fix the given SQL query to resolve the error.
|
| 260 |
+
|
| 261 |
+
IMPORTANT RULES:
|
| 262 |
+
1. Keep the fix as minimal as possible
|
| 263 |
+
2. Don't add complexity - keep it simple
|
| 264 |
+
3. Preserve the original intent of the query
|
| 265 |
+
4. Follow SQLite syntax rules
|
| 266 |
+
5. Don't add aliases or table prefixes unless necessary
|
| 267 |
+
|
| 268 |
+
Return ONLY the corrected SQL query, nothing else."""
|
| 269 |
+
|
| 270 |
+
user_prompt = f"""Fix this SQL query:
|
| 271 |
+
|
| 272 |
+
Original SQL: {sql}
|
| 273 |
+
|
| 274 |
+
Error: {error_msg}
|
| 275 |
+
|
| 276 |
+
Database Schema:
|
| 277 |
+
{schema_preview}
|
| 278 |
+
|
| 279 |
+
Return the corrected SQL (keep it simple):"""
|
| 280 |
|
|
|
|
| 281 |
completion = self.client.chat.completions.create(
|
| 282 |
model=self.model,
|
| 283 |
messages=[
|
| 284 |
+
{"role": "system", "content": system_prompt},
|
| 285 |
+
{"role": "user", "content": user_prompt},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
],
|
| 287 |
+
temperature=0.1,
|
| 288 |
)
|
| 289 |
+
|
| 290 |
+
text = completion.choices[0].message.content
|
| 291 |
+
fixed_sql = text.strip() if text else ""
|
| 292 |
+
|
| 293 |
+
# Clean up accidental code fences
|
| 294 |
+
if fixed_sql.startswith("```sql"):
|
| 295 |
+
fixed_sql = fixed_sql[6:]
|
| 296 |
+
if fixed_sql.startswith("```"):
|
| 297 |
+
fixed_sql = fixed_sql[3:]
|
| 298 |
+
if fixed_sql.endswith("```"):
|
| 299 |
+
fixed_sql = fixed_sql[:-3]
|
| 300 |
+
|
| 301 |
+
fixed_sql = fixed_sql.strip()
|
| 302 |
+
fixed_sql = self._simplify_sql(fixed_sql)
|
| 303 |
+
|
| 304 |
usage = completion.usage
|
| 305 |
+
|
| 306 |
+
if usage:
|
| 307 |
+
prompt_tokens = usage.prompt_tokens
|
| 308 |
+
completion_tokens = usage.completion_tokens
|
| 309 |
+
cost = self._estimate_cost(usage)
|
| 310 |
+
return (fixed_sql, prompt_tokens, completion_tokens, cost)
|
| 311 |
+
else:
|
| 312 |
+
return (fixed_sql, 0, 0, 0.0)
|
| 313 |
+
|
| 314 |
+
def _estimate_cost(self, usage: Any) -> float:
|
| 315 |
+
"""Estimate cost based on token usage.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
usage: OpenAI usage object with token counts
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
Estimated cost in USD
|
| 322 |
+
"""
|
| 323 |
+
if not usage:
|
| 324 |
+
return 0.0
|
| 325 |
+
|
| 326 |
+
# Pricing per 1K tokens (adjust based on model)
|
| 327 |
+
pricing = {
|
| 328 |
+
"gpt-4": {"input": 0.03, "output": 0.06},
|
| 329 |
+
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
|
| 330 |
+
"gpt-4o": {"input": 0.005, "output": 0.015},
|
| 331 |
+
"gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
|
| 332 |
+
"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
model_pricing = pricing.get(self.model, pricing["gpt-4o-mini"])
|
| 336 |
+
|
| 337 |
+
input_cost = (usage.prompt_tokens / 1000) * model_pricing["input"]
|
| 338 |
+
output_cost = (usage.completion_tokens / 1000) * model_pricing["output"]
|
| 339 |
+
|
| 340 |
+
return input_cost + output_cost
|
| 341 |
+
|
| 342 |
+
def clarify(
|
| 343 |
+
self,
|
| 344 |
+
*,
|
| 345 |
+
user_query: str,
|
| 346 |
+
schema_preview: str,
|
| 347 |
+
questions: List[str],
|
| 348 |
+
) -> Tuple[str, int, int, float]:
|
| 349 |
+
"""Clarify ambiguities in the user query.
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
user_query: The user's natural language question
|
| 353 |
+
schema_preview: Database schema information
|
| 354 |
+
questions: List of clarification questions
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
Tuple of (answers, prompt_tokens, completion_tokens, cost)
|
| 358 |
+
"""
|
| 359 |
+
system_prompt = """You are a helpful assistant that clarifies SQL query requirements.
|
| 360 |
+
Answer the questions clearly and concisely based on the user's query and database schema."""
|
| 361 |
+
|
| 362 |
+
user_prompt = f"""User Query: {user_query}
|
| 363 |
+
|
| 364 |
+
Database Schema:
|
| 365 |
+
{schema_preview}
|
| 366 |
+
|
| 367 |
+
Please answer these clarification questions:
|
| 368 |
+
{chr(10).join(f"{i + 1}. {q}" for i, q in enumerate(questions))}"""
|
| 369 |
+
|
| 370 |
+
completion = self.client.chat.completions.create(
|
| 371 |
+
model=self.model,
|
| 372 |
+
messages=[
|
| 373 |
+
{"role": "system", "content": system_prompt},
|
| 374 |
+
{"role": "user", "content": user_prompt},
|
| 375 |
+
],
|
| 376 |
+
temperature=0.3,
|
| 377 |
)
|
| 378 |
|
| 379 |
+
answers = completion.choices[0].message.content or ""
|
| 380 |
+
usage = completion.usage
|
| 381 |
+
|
| 382 |
+
if usage:
|
| 383 |
+
prompt_tokens = usage.prompt_tokens
|
| 384 |
+
completion_tokens = usage.completion_tokens
|
| 385 |
+
cost = self._estimate_cost(usage)
|
| 386 |
+
return (answers, prompt_tokens, completion_tokens, cost)
|
| 387 |
+
else:
|
| 388 |
+
return (answers, 0, 0, 0.0)
|
benchmarks/evaluate_spider_pro.py
CHANGED
|
@@ -1,490 +1,446 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
1) Single-DB demo mode (default)
|
| 6 |
-
- Runs a list of questions against one SQLite DB
|
| 7 |
-
- Reports latency/ok (no EM/SM/ExecAcc because there's no gold SQL)
|
| 8 |
-
|
| 9 |
-
2) Spider mode (--spider)
|
| 10 |
-
- Loads a subset of the Spider dataset via SPIDER_ROOT
|
| 11 |
-
- For each item, builds a per-DB pipeline and computes:
|
| 12 |
-
* EM (exact SQL string match, case-insensitive)
|
| 13 |
-
* SM (structural match via sqlglot AST)
|
| 14 |
-
* ExecAcc (result equivalence by executing gold vs. predicted SQL)
|
| 15 |
-
- Also logs latency, (optional) traces, and aggregates a summary
|
| 16 |
-
|
| 17 |
-
Works with:
|
| 18 |
-
- Real LLM (OPENAI_API_KEY set)
|
| 19 |
-
- Stub mode (PYTEST_CURRENT_TEST=1) for zero-cost offline runs
|
| 20 |
-
|
| 21 |
-
Outputs:
|
| 22 |
-
benchmarks/results_pro/<timestamp>/
|
| 23 |
-
- eval.jsonl # per-sample rows
|
| 24 |
-
- summary.json # aggregate metrics
|
| 25 |
-
- results.csv # human-friendly table
|
| 26 |
-
|
| 27 |
-
Examples:
|
| 28 |
-
# Demo (single DB), stub mode
|
| 29 |
-
PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
|
| 30 |
-
python benchmarks/evaluate_spider_pro.py --db-path demo.db
|
| 31 |
-
|
| 32 |
-
# Spider subset (20 items), stub mode
|
| 33 |
-
export SPIDER_ROOT=$PWD/data/spider
|
| 34 |
-
PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
|
| 35 |
-
python benchmarks/evaluate_spider_pro.py --spider --split dev --limit 20
|
| 36 |
"""
|
| 37 |
|
| 38 |
from __future__ import annotations
|
| 39 |
|
| 40 |
import argparse
|
| 41 |
-
import csv
|
| 42 |
import json
|
| 43 |
-
import
|
|
|
|
| 44 |
import time
|
|
|
|
|
|
|
| 45 |
from pathlib import Path
|
| 46 |
-
from typing import Any, Dict, List,
|
| 47 |
-
|
| 48 |
-
import sqlglot
|
| 49 |
-
from sqlglot.errors import ParseError
|
| 50 |
|
| 51 |
from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
|
| 52 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
try:
|
| 56 |
-
from benchmarks.spider_loader import load_spider_sqlite, open_readonly_connection
|
| 57 |
-
except Exception:
|
| 58 |
-
load_spider_sqlite = None # type: ignore[assignment]
|
| 59 |
-
open_readonly_connection = None # type: ignore[assignment]
|
| 60 |
-
|
| 61 |
-
# Resolve repo root and default config path relative to this file (not CWD)
|
| 62 |
-
THIS_DIR = Path(__file__).resolve().parent # .../benchmarks
|
| 63 |
-
REPO_ROOT = THIS_DIR.parent # repo root
|
| 64 |
-
CONFIG_PATH = str(REPO_ROOT / "configs" / "sqlite_pipeline.yaml")
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
# Default demo questions for single-DB mode
|
| 68 |
-
DEFAULT_DATASET: List[str] = [
|
| 69 |
-
"list all customers",
|
| 70 |
-
"show total invoices per country",
|
| 71 |
-
"top 3 albums by total sales",
|
| 72 |
-
"artists with more than 3 albums",
|
| 73 |
-
"number of employees per city",
|
| 74 |
-
]
|
| 75 |
-
|
| 76 |
-
RESULT_ROOT = Path("benchmarks") / "results_pro"
|
| 77 |
TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
|
| 78 |
RESULT_DIR = RESULT_ROOT / TIMESTAMP
|
| 79 |
|
| 80 |
|
| 81 |
-
#
|
| 82 |
|
| 83 |
|
| 84 |
-
def
|
| 85 |
-
"""
|
| 86 |
-
|
|
|
|
| 87 |
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
"""
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
getattr(pipeline_obj, "executor", None),
|
| 94 |
-
getattr(pipeline_obj, "adapter", None),
|
| 95 |
-
):
|
| 96 |
-
if c and hasattr(c, "derive_schema_preview"):
|
| 97 |
-
return c.derive_schema_preview() # type: ignore[no-any-return]
|
| 98 |
-
except Exception:
|
| 99 |
-
pass
|
| 100 |
-
return None
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
|
| 104 |
-
"""Normalize pipeline trace into a list of dicts for logging/export."""
|
| 105 |
-
out: List[Dict[str, Any]] = []
|
| 106 |
-
if not isinstance(trace_obj, list):
|
| 107 |
-
return out
|
| 108 |
-
for t in trace_obj:
|
| 109 |
-
if isinstance(t, dict):
|
| 110 |
-
stage = t.get("stage", "?")
|
| 111 |
-
ms = t.get("duration_ms", 0)
|
| 112 |
-
else:
|
| 113 |
-
stage = getattr(t, "stage", "?")
|
| 114 |
-
ms = getattr(t, "duration_ms", 0)
|
| 115 |
-
try:
|
| 116 |
-
out.append({"stage": str(stage), "ms": int(ms)})
|
| 117 |
-
except Exception:
|
| 118 |
-
out.append({"stage": str(stage), "ms": 0})
|
| 119 |
-
return out
|
| 120 |
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
except ParseError:
|
| 126 |
-
return None
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def _structural_match(pred: str, gold: str) -> bool:
|
| 130 |
-
"""AST-level equality via sqlglot; returns False if either side can't be parsed."""
|
| 131 |
-
a, b = _parse_sql(pred), _parse_sql(gold)
|
| 132 |
-
return (a == b) if (a is not None and b is not None) else False
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
def _load_dataset_from_file(path: Optional[str]) -> List[str]:
|
| 136 |
-
"""Load questions from a JSON file: list[str] or list[{question: str}]."""
|
| 137 |
-
if not path:
|
| 138 |
-
return DEFAULT_DATASET
|
| 139 |
-
p = Path(path)
|
| 140 |
-
if not p.exists():
|
| 141 |
-
raise FileNotFoundError(f"dataset file not found: {p}")
|
| 142 |
-
data = json.loads(p.read_text(encoding="utf-8"))
|
| 143 |
-
if isinstance(data, list):
|
| 144 |
-
if all(isinstance(x, str) for x in data):
|
| 145 |
-
return list(data)
|
| 146 |
-
if all(isinstance(x, dict) and "question" in x for x in data):
|
| 147 |
-
return [str(x["question"]) for x in data]
|
| 148 |
-
raise ValueError(
|
| 149 |
-
"Dataset file must be a JSON array of strings or objects with 'question' field."
|
| 150 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
-
def
|
| 154 |
-
"""
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
"""
|
| 158 |
-
sql_pred: Optional[str] = getattr(result, "sql", None)
|
| 159 |
-
if not sql_pred:
|
| 160 |
-
data = getattr(result, "data", None)
|
| 161 |
-
if data is not None:
|
| 162 |
-
sql_pred = getattr(data, "sql", None)
|
| 163 |
-
return (sql_pred or "").strip()
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
RESULT_DIR.mkdir(parents=True, exist_ok=True)
|
| 169 |
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
for r in rows:
|
| 173 |
-
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 174 |
-
|
| 175 |
-
with (RESULT_DIR / "summary.json").open("w", encoding="utf-8") as f:
|
| 176 |
-
json.dump(summary, f, indent=2)
|
| 177 |
-
|
| 178 |
-
csv_path = RESULT_DIR / "results.csv"
|
| 179 |
-
# For pro, include pro columns when present (Spider mode)
|
| 180 |
-
fieldnames = [
|
| 181 |
-
"source",
|
| 182 |
-
"db_id",
|
| 183 |
-
"query",
|
| 184 |
-
"em",
|
| 185 |
-
"sm",
|
| 186 |
-
"exec_acc",
|
| 187 |
-
"ok",
|
| 188 |
-
"latency_ms",
|
| 189 |
-
]
|
| 190 |
-
with csv_path.open("w", newline="", encoding="utf-8") as f:
|
| 191 |
-
wr = csv.DictWriter(f, fieldnames=fieldnames)
|
| 192 |
-
wr.writeheader()
|
| 193 |
-
for r in rows:
|
| 194 |
-
wr.writerow(
|
| 195 |
-
{
|
| 196 |
-
"source": r.get("source", "demo"),
|
| 197 |
-
"db_id": r.get("db_id", ""),
|
| 198 |
-
"query": r.get("query", ""),
|
| 199 |
-
"em": "✅" if r.get("em") else "❌" if "em" in r else "",
|
| 200 |
-
"sm": "✅" if r.get("sm") else "❌" if "sm" in r else "",
|
| 201 |
-
"exec_acc": "✅"
|
| 202 |
-
if r.get("exec_acc")
|
| 203 |
-
else "❌"
|
| 204 |
-
if "exec_acc" in r
|
| 205 |
-
else "",
|
| 206 |
-
"ok": "✅" if r.get("ok") else "❌",
|
| 207 |
-
"latency_ms": int(r.get("latency_ms", 0)),
|
| 208 |
-
}
|
| 209 |
-
)
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
f"- {jsonl_path}\n- {RESULT_DIR / 'summary.json'}\n- {csv_path}\n"
|
| 214 |
-
f"📊 Avg latency: {summary.get('avg_latency_ms', 0.0)} ms "
|
| 215 |
-
f"| EM: {summary.get('EM', 0.0):.3f} "
|
| 216 |
-
f"| SM: {summary.get('SM', 0.0):.3f} "
|
| 217 |
-
f"| ExecAcc: {summary.get('ExecAcc', 0.0):.3f} "
|
| 218 |
-
f"| Success: {summary.get('success_rate', 0.0):.0%}\n"
|
| 219 |
-
)
|
| 220 |
|
|
|
|
|
|
|
| 221 |
|
| 222 |
-
#
|
|
|
|
| 223 |
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
-
|
| 226 |
-
"""
|
| 227 |
-
Single-DB demo mode.
|
| 228 |
-
Only latency/ok is reported (no EM/SM/ExecAcc, because we don't have gold SQL).
|
| 229 |
-
"""
|
| 230 |
-
adapter = SQLiteAdapter(str(db_path))
|
| 231 |
-
pipeline = pipeline_from_config_with_adapter(config_path, adapter=adapter)
|
| 232 |
|
| 233 |
-
schema_preview = _derive_schema_preview_safe(pipeline)
|
| 234 |
-
if schema_preview:
|
| 235 |
-
print("📄 Derived schema preview ✓")
|
| 236 |
-
else:
|
| 237 |
-
print("ℹ️ No schema preview (adapter does not expose it or not needed)")
|
| 238 |
|
| 239 |
-
|
| 240 |
-
for q in questions:
|
| 241 |
-
print(f"\n🧠 Query: {q}")
|
| 242 |
-
t0 = time.perf_counter()
|
| 243 |
-
try:
|
| 244 |
-
result = pipeline.run(user_query=q, schema_preview=schema_preview or "")
|
| 245 |
-
latency_ms = _int_ms(t0) or 1 # clamp to 1ms for nicer CSV in stub mode
|
| 246 |
-
stages = _to_stage_list(
|
| 247 |
-
getattr(result, "traces", getattr(result, "trace", []))
|
| 248 |
-
)
|
| 249 |
-
rows.append(
|
| 250 |
-
{
|
| 251 |
-
"source": "demo",
|
| 252 |
-
"db_id": Path(db_path).stem,
|
| 253 |
-
"query": q,
|
| 254 |
-
"ok": bool(getattr(result, "ok", True)),
|
| 255 |
-
"latency_ms": latency_ms,
|
| 256 |
-
"trace": stages,
|
| 257 |
-
"error": None,
|
| 258 |
-
}
|
| 259 |
-
)
|
| 260 |
-
print(f"✅ Success ({latency_ms} ms)")
|
| 261 |
-
except Exception as exc:
|
| 262 |
-
latency_ms = _int_ms(t0) or 1
|
| 263 |
-
rows.append(
|
| 264 |
-
{
|
| 265 |
-
"source": "demo",
|
| 266 |
-
"db_id": Path(db_path).stem,
|
| 267 |
-
"query": q,
|
| 268 |
-
"ok": False,
|
| 269 |
-
"latency_ms": latency_ms,
|
| 270 |
-
"trace": [],
|
| 271 |
-
"error": str(exc),
|
| 272 |
-
}
|
| 273 |
-
)
|
| 274 |
-
print(f"❌ Failed: {exc!s} ({latency_ms} ms)")
|
| 275 |
|
| 276 |
-
success_rate = (
|
| 277 |
-
(sum(1 for r in rows if r.get("ok")) / max(len(rows), 1)) if rows else 0.0
|
| 278 |
-
)
|
| 279 |
-
avg_latency = (
|
| 280 |
-
round(sum(int(r.get("latency_ms", 0)) for r in rows) / max(len(rows), 1), 1)
|
| 281 |
-
if rows
|
| 282 |
-
else 0.0
|
| 283 |
-
)
|
| 284 |
-
summary = {
|
| 285 |
-
"mode": "single-db",
|
| 286 |
-
"db_path": str(db_path),
|
| 287 |
-
"config": config_path,
|
| 288 |
-
"provider_hint": ("STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL"),
|
| 289 |
-
"total": len(rows),
|
| 290 |
-
"EM": 0.0,
|
| 291 |
-
"SM": 0.0,
|
| 292 |
-
"ExecAcc": 0.0, # not applicable in demo
|
| 293 |
-
"success_rate": success_rate,
|
| 294 |
-
"avg_latency_ms": avg_latency,
|
| 295 |
-
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
| 296 |
-
}
|
| 297 |
-
_save_outputs(rows, summary)
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
"""
|
| 305 |
-
if load_spider_sqlite is None or open_readonly_connection is None:
|
| 306 |
-
raise RuntimeError(
|
| 307 |
-
"Spider utilities are not available. Ensure benchmarks/spider_loader.py exists."
|
| 308 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
-
|
| 311 |
-
|
|
|
|
| 312 |
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
| 322 |
|
| 323 |
-
|
| 324 |
-
conn = open_readonly_connection(ex.db_path)
|
| 325 |
|
| 326 |
-
t0 = time.perf_counter()
|
| 327 |
-
try:
|
| 328 |
-
result = pipeline.run(
|
| 329 |
-
user_query=ex.question, schema_preview=schema_preview or ""
|
| 330 |
-
)
|
| 331 |
-
latency_ms = _int_ms(t0) or 1
|
| 332 |
-
stages = _to_stage_list(
|
| 333 |
-
getattr(result, "traces", getattr(result, "trace", []))
|
| 334 |
-
)
|
| 335 |
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
gold_sql = ex.gold_sql.strip()
|
| 341 |
-
em = (sql_pred.lower() == gold_sql.lower()) if sql_pred else False
|
| 342 |
-
sm = _structural_match(sql_pred, gold_sql) if sql_pred else False
|
| 343 |
-
|
| 344 |
-
try:
|
| 345 |
-
gold_exec = conn.execute(gold_sql).fetchall()
|
| 346 |
-
except Exception:
|
| 347 |
-
gold_exec = []
|
| 348 |
-
try:
|
| 349 |
-
pred_exec = conn.execute(sql_pred).fetchall() if sql_pred else []
|
| 350 |
-
except Exception:
|
| 351 |
-
pred_exec = []
|
| 352 |
-
exec_acc = gold_exec == pred_exec
|
| 353 |
-
|
| 354 |
-
rows.append(
|
| 355 |
-
{
|
| 356 |
-
"source": "spider",
|
| 357 |
-
"db_id": ex.db_id,
|
| 358 |
-
"query": ex.question,
|
| 359 |
-
"sql_pred": sql_pred,
|
| 360 |
-
"sql_gold": gold_sql,
|
| 361 |
-
"em": em,
|
| 362 |
-
"sm": sm,
|
| 363 |
-
"exec_acc": exec_acc,
|
| 364 |
-
"ok": bool(getattr(result, "ok", True)),
|
| 365 |
-
"latency_ms": latency_ms,
|
| 366 |
-
"trace": stages,
|
| 367 |
-
"error": None,
|
| 368 |
-
}
|
| 369 |
-
)
|
| 370 |
-
print(f"✅ OK | EM={em} | SM={sm} | Exec={exec_acc} | {latency_ms} ms")
|
| 371 |
-
except Exception as exc:
|
| 372 |
-
latency_ms = _int_ms(t0) or 1
|
| 373 |
-
rows.append(
|
| 374 |
-
{
|
| 375 |
-
"source": "spider",
|
| 376 |
-
"db_id": ex.db_id,
|
| 377 |
-
"query": ex.question,
|
| 378 |
-
"sql_pred": None,
|
| 379 |
-
"sql_gold": ex.gold_sql,
|
| 380 |
-
"em": False,
|
| 381 |
-
"sm": False,
|
| 382 |
-
"exec_acc": False,
|
| 383 |
-
"ok": False,
|
| 384 |
-
"latency_ms": latency_ms,
|
| 385 |
-
"trace": [],
|
| 386 |
-
"error": str(exc),
|
| 387 |
-
}
|
| 388 |
-
)
|
| 389 |
-
print(f"❌ Fail: {exc!s} ({latency_ms} ms)")
|
| 390 |
-
finally:
|
| 391 |
-
try:
|
| 392 |
-
conn.close()
|
| 393 |
-
except Exception:
|
| 394 |
-
pass
|
| 395 |
-
|
| 396 |
-
# Aggregate pro metrics
|
| 397 |
-
total = len(rows)
|
| 398 |
-
em_rate = (sum(1 for r in rows if r.get("em")) / max(total, 1)) if rows else 0.0
|
| 399 |
-
sm_rate = (sum(1 for r in rows if r.get("sm")) / max(total, 1)) if rows else 0.0
|
| 400 |
-
exec_rate = (
|
| 401 |
-
(sum(1 for r in rows if r.get("exec_acc")) / max(total, 1)) if rows else 0.0
|
| 402 |
-
)
|
| 403 |
-
success_rate = (
|
| 404 |
-
(sum(1 for r in rows if r.get("ok")) / max(total, 1)) if rows else 0.0
|
| 405 |
-
)
|
| 406 |
-
avg_latency = (
|
| 407 |
-
round(sum(int(r.get("latency_ms", 0)) for r in rows) / max(total, 1), 1)
|
| 408 |
-
if rows
|
| 409 |
-
else 0.0
|
| 410 |
-
)
|
| 411 |
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
"
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
"EM": round(em_rate, 3),
|
| 421 |
-
"SM": round(sm_rate, 3),
|
| 422 |
-
"ExecAcc": round(exec_rate, 3),
|
| 423 |
-
"success_rate": success_rate,
|
| 424 |
-
"avg_latency_ms": avg_latency,
|
| 425 |
-
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
| 426 |
-
}
|
| 427 |
-
_save_outputs(rows, summary)
|
| 428 |
|
|
|
|
|
|
|
| 429 |
|
| 430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
|
|
|
|
| 432 |
|
| 433 |
-
|
| 434 |
-
ap = argparse.ArgumentParser()
|
| 435 |
-
ap.add_argument(
|
| 436 |
-
"--spider",
|
| 437 |
-
action="store_true",
|
| 438 |
-
help="Enable Spider mode (reads from SPIDER_ROOT; ignores --db-path).",
|
| 439 |
-
)
|
| 440 |
-
ap.add_argument(
|
| 441 |
-
"--split",
|
| 442 |
-
type=str,
|
| 443 |
-
default="dev",
|
| 444 |
-
choices=["dev", "train"],
|
| 445 |
-
help="Spider split to use (default: dev).",
|
| 446 |
-
)
|
| 447 |
-
ap.add_argument(
|
| 448 |
-
"--limit",
|
| 449 |
-
type=int,
|
| 450 |
-
default=20,
|
| 451 |
-
help="Number of Spider items to evaluate (default: 20).",
|
| 452 |
-
)
|
| 453 |
|
| 454 |
-
ap.add_argument(
|
| 455 |
-
"--db-path",
|
| 456 |
-
type=str,
|
| 457 |
-
default="demo.db",
|
| 458 |
-
help="Path to SQLite database file (single-DB mode).",
|
| 459 |
-
)
|
| 460 |
-
ap.add_argument(
|
| 461 |
-
"--dataset-file",
|
| 462 |
-
type=str,
|
| 463 |
-
default=None,
|
| 464 |
-
help="Optional JSON file with questions (single-DB mode).",
|
| 465 |
-
)
|
| 466 |
-
ap.add_argument(
|
| 467 |
-
"--config",
|
| 468 |
-
type=str,
|
| 469 |
-
default=CONFIG_PATH,
|
| 470 |
-
help=f"Pipeline YAML config (default: {CONFIG_PATH})",
|
| 471 |
-
)
|
| 472 |
-
args = ap.parse_args()
|
| 473 |
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
)
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
|
| 489 |
if __name__ == "__main__":
|
|
|
|
| 490 |
main()
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Enhanced Spider benchmark evaluator for NL2SQL pipeline.
|
| 4 |
+
No external dependencies - uses internal evaluation logic.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
import argparse
|
|
|
|
| 10 |
import json
|
| 11 |
+
import re
|
| 12 |
+
import sqlite3
|
| 13 |
import time
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from datetime import datetime
|
| 16 |
from pathlib import Path
|
| 17 |
+
from typing import Any, Dict, List, Tuple
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
|
| 20 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 21 |
+
from benchmarks.spider_loader import load_spider_sqlite
|
| 22 |
+
|
| 23 |
+
# ==================== Configuration ====================
|
| 24 |
|
| 25 |
+
RESULT_ROOT = Path("benchmarks/results_pro")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
|
| 27 |
RESULT_DIR = RESULT_ROOT / TIMESTAMP
|
| 28 |
|
| 29 |
|
| 30 |
+
# ==================== SQL Processing ====================
|
| 31 |
|
| 32 |
|
| 33 |
+
def extract_clean_sql(text: str | None) -> str:
|
| 34 |
+
"""Safely extract a clean SQL string from input text possibly containing markdown fences or JSON."""
|
| 35 |
+
# Always initialize variable to empty string
|
| 36 |
+
sql = text or ""
|
| 37 |
|
| 38 |
+
# Remove markdown code fences
|
| 39 |
+
sql = re.sub(r"```(?:sql)?\s*\n?", "", sql, flags=re.IGNORECASE)
|
| 40 |
+
sql = re.sub(r"```\s*$", "", sql)
|
| 41 |
|
| 42 |
+
# Try JSON pattern like {"sql": "..."}
|
| 43 |
+
m_json = re.search(r'"sql"\s*:\s*"([^"]+)"', sql)
|
| 44 |
+
if m_json:
|
| 45 |
+
sql = m_json.group(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
# Clean escaped characters
|
| 48 |
+
sql = sql.replace('\\"', '"').replace("\\n", " ").replace("\\t", " ")
|
| 49 |
|
| 50 |
+
# Try to locate SQL statement keywords
|
| 51 |
+
m_sql = re.search(
|
| 52 |
+
r"\b(select|with|insert|update|delete)\b[\s\S]+", sql, re.IGNORECASE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
)
|
| 54 |
+
if m_sql:
|
| 55 |
+
sql = m_sql.group(0)
|
| 56 |
+
sql = re.sub(r"\s+", " ", sql).strip().rstrip(";")
|
| 57 |
+
return sql
|
| 58 |
|
| 59 |
|
| 60 |
+
def normalize_sql(sql: str) -> str:
    """Enhanced SQL normalization for better matching."""
    if not sql:
        return ""

    # Uppercase, collapse whitespace, drop a trailing semicolon.
    text = re.sub(r"\s+", " ", sql.strip().upper()).rstrip(";")

    # Strip table qualifiers (e.g. SINGER.NAME -> NAME).
    text = re.sub(r"\b\w+\.(\w+)\b", r"\1", text)

    # Strip AS aliases.
    text = re.sub(r"\s+AS\s+\w+", "", text, flags=re.IGNORECASE)

    # Treat COUNT(DISTINCT col) like COUNT(col) ...
    text = re.sub(r"COUNT\s*\(\s*DISTINCT\s+", "COUNT(", text)
    # ... and fold COUNT(col) into COUNT(*).
    text = re.sub(r"COUNT\s*\(\s*\w+\s*\)", "COUNT(*)", text)

    # Ignore a trailing LIMIT clause.
    text = re.sub(r"\s+LIMIT\s+\d+$", "", text)

    # Unquote "ident" and `ident`.
    text = re.sub(r'"(\w+)"', r"\1", text)
    text = re.sub(r"`(\w+)`", r"\1", text)

    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
# ==================== Schema Extraction ====================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
def get_database_schema(db_path: Path) -> Dict[str, Any]:
    """Extract complete schema from SQLite database."""
    if not db_path.exists():
        return {}

    schema: dict[str, Any] = {"tables": {}}
    connection = sqlite3.connect(str(db_path))
    try:
        cur = connection.cursor()
        # Enumerate user tables (skip SQLite's internal ones).
        cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
        table_names = [row[0] for row in cur.fetchall()]

        for table in table_names:
            # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
            cur.execute(f"PRAGMA table_info('{table}')")
            columns = [
                {"name": row[1], "type": row[2], "primary_key": bool(row[5])}
                for row in cur.fetchall()
            ]

            # PRAGMA foreign_key_list rows: (id, seq, table, from, to, ...)
            cur.execute(f"PRAGMA foreign_key_list('{table}')")
            foreign_keys = [
                {
                    "column": row[3],
                    "referenced_table": row[2],
                    "referenced_column": row[4],
                }
                for row in cur.fetchall()
            ]

            schema["tables"][table] = {
                "columns": columns,
                "foreign_keys": foreign_keys,
            }
    finally:
        connection.close()

    return schema
|
|
|
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
+
def format_schema_for_prompt(schema: Dict[str, Any]) -> str:
    """Format schema for LLM prompt."""
    tables = (schema or {}).get("tables")
    if not tables:
        return ""

    out: list[str] = []
    for table_name, info in tables.items():
        # Render each column as "name TYPE [PRIMARY KEY]".
        rendered_cols = []
        for col in info["columns"]:
            piece = f"{col['name']} {col['type']}"
            if col.get("primary_key"):
                piece += " PRIMARY KEY"
            rendered_cols.append(piece)

        out.append(f"Table: {table_name}")
        out.append(f"Columns: {', '.join(rendered_cols)}")

        # Foreign keys are only listed when present.
        fk_list = info.get("foreign_keys")
        if fk_list:
            rendered_fks = [
                f"{fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}"
                for fk in fk_list
            ]
            out.append(f"Foreign Keys: {', '.join(rendered_fks)}")

        out.append("")  # Empty line between tables

    return "\n".join(out).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
+
# ==================== SQL Evaluation ====================
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def execute_sql(db_path: Path, sql: str) -> Tuple[bool, List[Tuple]]:
    """Execute SQL against the SQLite database at *db_path*.

    Args:
        db_path: Path to the SQLite database file.
        sql: SQL statement to run; empty/None-ish strings are rejected.

    Returns:
        (success, rows): success is False for empty SQL or any execution
        error; rows is the fetched result set (empty list on failure).
    """
    if not sql:
        return False, []

    conn = None
    try:
        conn = sqlite3.connect(str(db_path))
        rows = conn.execute(sql).fetchall()
        return True, rows
    except Exception:
        # Any SQLite error (syntax, missing table, locked db, ...) counts
        # as a failed execution for benchmark-scoring purposes.
        return False, []
    finally:
        # Fix: the original leaked the connection whenever execute() raised
        # (close() was only reached on the success path).
        if conn is not None:
            conn.close()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def compare_sql_results(gold_results: List[Tuple], pred_results: List[Tuple]) -> bool:
    """Compare SQL execution results order-independently.

    Uses multiset (Counter) equality so duplicate rows must occur the same
    number of times in both result sets. The original set-based comparison
    wrongly equated equal-length results such as [a, a, b] and [a, b, b]
    because converting to ``set`` discards duplicate multiplicity.

    Args:
        gold_results: Rows fetched by the gold query.
        pred_results: Rows fetched by the predicted query.

    Returns:
        True when both result sets contain exactly the same rows
        (ignoring order, respecting duplicates).
    """
    from collections import Counter

    # Cheap early exit; also preserves the original length guard.
    if len(gold_results) != len(pred_results):
        return False

    return Counter(gold_results) == Counter(pred_results)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def evaluate_sql_match(pred_sql: str, gold_sql: str, db_path: Path) -> Dict[str, float]:
    """Evaluate predicted SQL against gold SQL."""
    scores = {"exact_match": 0.0, "set_match": 0.0, "exec_accuracy": 0.0}

    # An empty prediction scores zero on every metric.
    if not pred_sql:
        return scores

    # String-level exact match after normalization.
    if normalize_sql(pred_sql) == normalize_sql(gold_sql):
        scores["exact_match"] = 1.0

    # Execution-level comparison against the live database.
    gold_ok, gold_rows = execute_sql(db_path, gold_sql)
    pred_ok, pred_rows = execute_sql(db_path, pred_sql)

    if gold_ok and pred_ok:
        if compare_sql_results(gold_rows, pred_rows):
            # Identical result sets: full credit on both execution metrics.
            scores["set_match"] = 1.0
            scores["exec_accuracy"] = 1.0
        else:
            # Ran without error but produced different rows: partial credit.
            scores["exec_accuracy"] = 0.5

    return scores
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ==================== Pipeline Runner ====================
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@dataclass
class SpiderSample:
    """One Spider benchmark example: a natural-language question paired
    with its database and the reference (gold) SQL answer."""

    question: str  # natural-language question to translate into SQL
    db_id: str  # Spider database identifier
    db_path: Path  # filesystem path to the SQLite database file
    gold_sql: str  # reference SQL query used for scoring
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def run_pipeline_on_sample(
    pipeline: Any,
    sample: SpiderSample,
    schema_cache: Dict[str, str],
    debug: bool = False,
) -> Dict[str, Any]:
    """Run NL2SQL pipeline on a single sample.

    Args:
        pipeline: Pipeline object exposing ``run(user_query=..., schema_preview=...)``.
        sample: Spider example (question, database id/path, gold SQL).
        schema_cache: Mutable db_id -> formatted-schema map; populated lazily
            so each database's schema is extracted only once across samples.
        debug: When True, print schema-loading info and full tracebacks.

    Returns:
        Dict with keys ``ok``, ``sql`` (cleaned prediction, "" on failure),
        ``raw_response``, ``traces``, and ``error`` (None on success).
    """

    # Get/cache schema (one extraction per database, reused for later samples)
    if sample.db_id not in schema_cache:
        schema_dict = get_database_schema(sample.db_path)
        schema_str = format_schema_for_prompt(schema_dict)
        schema_cache[sample.db_id] = schema_str
        if debug:
            print(f" [schema] Loaded {len(schema_str)} chars for {sample.db_id}")

    schema: str = schema_cache[sample.db_id]

    # Run pipeline
    try:
        result = pipeline.run(user_query=sample.question, schema_preview=schema)

        # Extract SQL from result
        if hasattr(result, "sql") and result.sql:
            pred_sql = extract_clean_sql(result.sql)
        else:
            # Try to extract from various fields; the for-else sets an empty
            # prediction only when no attribute yielded usable SQL (no break).
            for attr in ["final_sql", "generated_sql", "answer"]:
                if hasattr(result, attr):
                    val = getattr(result, attr)
                    if val:
                        pred_sql = extract_clean_sql(str(val))
                        if pred_sql:
                            break
            else:
                pred_sql = ""

        return {
            "ok": bool(getattr(result, "ok", True)),
            "sql": pred_sql,
            "raw_response": getattr(result, "sql", ""),
            "traces": getattr(result, "traces", []),
            "error": None,
        }

    except Exception as e:
        if debug:
            import traceback

            traceback.print_exc()
        # Pipeline crash: report a failed, empty result with the error text.
        return {
            "ok": False,
            "sql": "",
            "raw_response": "",
            "traces": [],
            "error": str(e),
        }
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# ==================== Main Evaluation ====================
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def main():
    """CLI entry point: evaluate the NL2SQL pipeline on the Spider benchmark.

    Loads the requested Spider split, runs the pipeline on each sample,
    scores exact-match / set-match / execution accuracy, and writes
    ``eval.jsonl`` plus ``summary.json`` under RESULT_DIR.
    """
    parser = argparse.ArgumentParser(description="Evaluate NL2SQL on Spider")
    parser.add_argument("--spider", action="store_true", help="Run Spider evaluation")
    parser.add_argument("--split", default="dev", choices=["dev", "train"])
    parser.add_argument("--limit", type=int, help="Limit number of samples")
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    parser.add_argument("--config", default="configs/sqlite_pipeline.yaml")

    args = parser.parse_args()

    if not args.spider:
        print("Please use --spider flag to run Spider evaluation")
        return

    # Load Spider samples
    print(f"Loading Spider {args.split} split...")
    samples = load_spider_sqlite(split=args.split, limit=args.limit)

    if not samples:
        print("❌ No samples loaded. Check SPIDER_ROOT environment variable.")
        return

    print(f"✔ Loaded {len(samples)} samples")

    # Prepare results directory
    RESULT_DIR.mkdir(parents=True, exist_ok=True)

    # Initialize schema cache (shared across samples so each DB schema is
    # extracted only once)
    schema_cache = {}

    # Process each sample
    results = []
    for i, spider_item in enumerate(samples, 1):
        # Convert to our sample format
        sample = SpiderSample(
            question=spider_item.question,
            db_id=spider_item.db_id,
            db_path=Path(spider_item.db_path),
            gold_sql=spider_item.gold_sql,
        )

        print(f"\n🧠 [{i}/{len(samples)}] [{sample.db_id}] {sample.question}")

        # Create adapter and pipeline for this database
        adapter = SQLiteAdapter(sample.db_path)
        pipeline = pipeline_from_config_with_adapter(args.config, adapter=adapter)

        # Run pipeline (latency measured around the full pipeline call)
        t0 = time.perf_counter()
        result = run_pipeline_on_sample(pipeline, sample, schema_cache, args.debug)
        latency_ms = int((time.perf_counter() - t0) * 1000)

        # Evaluate
        metrics = evaluate_sql_match(result["sql"], sample.gold_sql, sample.db_path)

        # Store result
        eval_result = {
            "source": "spider",
            "db_id": sample.db_id,
            "query": sample.question,
            "gold_sql": sample.gold_sql,
            "pred_sql": result["sql"],
            "ok": result["ok"],
            "latency_ms": latency_ms,
            "em": metrics["exact_match"],
            "sm": metrics["set_match"],
            "exec_acc": metrics["exec_accuracy"],
            "error": result.get("error"),
            "trace": result.get("traces", []),
        }
        results.append(eval_result)

        # Debug output
        if args.debug:
            status = "✅" if result["ok"] and metrics["exact_match"] == 1 else "⚠️"
            print(
                f"{status} ({latency_ms} ms) | EM={metrics['exact_match']:.0f} SM={metrics['set_match']:.0f} ExecAcc={metrics['exec_accuracy']:.1f}"
            )
            if metrics["exact_match"] < 1:
                print(f" gold: {sample.gold_sql[:100]}")
                print(f" pred: {result['sql'][:100] if result['sql'] else 'EMPTY'}")

    # Calculate aggregates
    total = len(results)
    successful = sum(1 for r in results if r["ok"])
    avg_em = sum(r["em"] for r in results) / total if total > 0 else 0
    avg_sm = sum(r["sm"] for r in results) / total if total > 0 else 0
    avg_ea = sum(r["exec_acc"] for r in results) / total if total > 0 else 0
    avg_latency = sum(r["latency_ms"] for r in results) / total if total > 0 else 0

    # Save results (one JSON object per line)
    eval_jsonl = RESULT_DIR / "eval.jsonl"
    with open(eval_jsonl, "w") as f:
        for r in results:
            json.dump(r, f, ensure_ascii=False)
            f.write("\n")

    summary = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "total": total,
        "success": successful,
        "success_rate": round(successful / total, 3) if total else 0,
        "avg_latency_ms": round(avg_latency, 1),
        "EM": round(avg_em, 3),
        "SM": round(avg_sm, 3),
        "ExecAcc": round(avg_ea, 3),
        "split": args.split,
        "config": args.config,
    }

    (RESULT_DIR / "summary.json").write_text(
        json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    print("\n================== Evaluation Summary ==================")
    print(f"Total samples: {total}")
    print(f"Successful runs: {successful} ({summary['success_rate'] * 100:.1f}%)")
    print(f"Avg EM: {summary['EM']}")
    print(f"Avg SM: {summary['SM']}")
    print(f"Avg ExecAcc: {summary['ExecAcc']}")
    print(f"Avg Latency: {summary['avg_latency_ms']} ms")
    print(f"Results saved to {RESULT_DIR}")
    print("========================================================")
|
| 442 |
|
| 443 |
|
| 444 |
if __name__ == "__main__":
|
| 445 |
+
RESULT_DIR.mkdir(parents=True, exist_ok=True)
|
| 446 |
main()
|
benchmarks/results_pro/20251108-123204/eval.jsonl
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
{"source": "demo", "db_id": "demo", "query": "list all customers", "ok": false, "latency_ms": 8406, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 3768}, {"stage": "generator", "ms": 1616}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 3}, {"stage": "repair", "ms": 1639}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1367}, {"stage": "safety", "ms": 3}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 2 |
-
{"source": "demo", "db_id": "demo", "query": "show total invoices per country", "ok": true, "latency_ms": 11003, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 5021}, {"stage": "generator", "ms": 1605}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1437}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 2929}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 3 |
-
{"source": "demo", "db_id": "demo", "query": "top 3 albums by total sales", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}], "error": null}
|
| 4 |
-
{"source": "demo", "db_id": "demo", "query": "artists with more than 3 albums", "ok": false, "latency_ms": 14409, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 8377}, {"stage": "generator", "ms": 2525}, {"stage": "safety", "ms": 4}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1618}, {"stage": "safety", "ms": 4}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1874}, {"stage": "safety", "ms": 3}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 5 |
-
{"source": "demo", "db_id": "demo", "query": "number of employees per city", "ok": true, "latency_ms": 8938, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4402}, {"stage": "generator", "ms": 1846}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1397}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1283}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-123204/latency_per_stage.png
DELETED
|
Binary file (34.7 kB)
|
|
|
benchmarks/results_pro/20251108-123204/metrics_overview.png
DELETED
|
Binary file (22.7 kB)
|
|
|
benchmarks/results_pro/20251108-123204/results.csv
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
source,db_id,query,em,sm,exec_acc,ok,latency_ms
|
| 2 |
-
demo,demo,list all customers,,,,❌,8406
|
| 3 |
-
demo,demo,show total invoices per country,,,,✅,11003
|
| 4 |
-
demo,demo,top 3 albums by total sales,,,,✅,1
|
| 5 |
-
demo,demo,artists with more than 3 albums,,,,❌,14409
|
| 6 |
-
demo,demo,number of employees per city,,,,✅,8938
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-123204/summary.json
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"mode": "single-db",
|
| 3 |
-
"db_path": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/demo.db",
|
| 4 |
-
"config": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/configs/sqlite_pipeline.yaml",
|
| 5 |
-
"provider_hint": "REAL",
|
| 6 |
-
"total": 5,
|
| 7 |
-
"EM": 0.0,
|
| 8 |
-
"SM": 0.0,
|
| 9 |
-
"ExecAcc": 0.0,
|
| 10 |
-
"success_rate": 0.6,
|
| 11 |
-
"avg_latency_ms": 8551.4,
|
| 12 |
-
"timestamp": "2025-11-08 12:32:47"
|
| 13 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-124153/eval.jsonl
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
{"source": "demo", "db_id": "demo", "query": "list all customers", "ok": false, "latency_ms": 6756, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 2729}, {"stage": "generator", "ms": 1343}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 2}, {"stage": "repair", "ms": 911}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1763}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 2 |
-
{"source": "demo", "db_id": "demo", "query": "show total invoices per country", "ok": true, "latency_ms": 8901, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4799}, {"stage": "generator", "ms": 1075}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1092}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1924}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 3 |
-
{"source": "demo", "db_id": "demo", "query": "top 3 albums by total sales", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}], "error": null}
|
| 4 |
-
{"source": "demo", "db_id": "demo", "query": "artists with more than 3 albums", "ok": false, "latency_ms": 12342, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4882}, {"stage": "generator", "ms": 2684}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 2630}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 2135}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 5 |
-
{"source": "demo", "db_id": "demo", "query": "number of employees per city", "ok": true, "latency_ms": 7547, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4083}, {"stage": "generator", "ms": 1269}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1149}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1035}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-124153/latency_per_stage.png
DELETED
|
Binary file (34.7 kB)
|
|
|
benchmarks/results_pro/20251108-124153/metrics_overview.png
DELETED
|
Binary file (22.7 kB)
|
|
|
benchmarks/results_pro/20251108-124153/results.csv
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
source,db_id,query,em,sm,exec_acc,ok,latency_ms
|
| 2 |
-
demo,demo,list all customers,,,,❌,6756
|
| 3 |
-
demo,demo,show total invoices per country,,,,✅,8901
|
| 4 |
-
demo,demo,top 3 albums by total sales,,,,✅,1
|
| 5 |
-
demo,demo,artists with more than 3 albums,,,,❌,12342
|
| 6 |
-
demo,demo,number of employees per city,,,,✅,7547
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-124153/summary.json
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"mode": "single-db",
|
| 3 |
-
"db_path": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/demo.db",
|
| 4 |
-
"config": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/configs/sqlite_pipeline.yaml",
|
| 5 |
-
"provider_hint": "REAL",
|
| 6 |
-
"total": 5,
|
| 7 |
-
"EM": 0.0,
|
| 8 |
-
"SM": 0.0,
|
| 9 |
-
"ExecAcc": 0.0,
|
| 10 |
-
"success_rate": 0.6,
|
| 11 |
-
"avg_latency_ms": 7109.4,
|
| 12 |
-
"timestamp": "2025-11-08 12:42:29"
|
| 13 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-125829/eval.jsonl
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
{"source": "demo", "db_id": "demo", "query": "list all customers", "ok": false, "latency_ms": 6652, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 2554}, {"stage": "generator", "ms": 1370}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 1}, {"stage": "repair", "ms": 1295}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "repair", "ms": 1426}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 2 |
-
{"source": "demo", "db_id": "demo", "query": "show total invoices per country", "ok": true, "latency_ms": 7375, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 3866}, {"stage": "generator", "ms": 1265}, {"stage": "safety", "ms": 4}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1126}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1106}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 3 |
-
{"source": "demo", "db_id": "demo", "query": "top 3 albums by total sales", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}], "error": null}
|
| 4 |
-
{"source": "demo", "db_id": "demo", "query": "artists with more than 3 albums", "ok": false, "latency_ms": 8629, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4110}, {"stage": "generator", "ms": 1969}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1296}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1244}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
| 5 |
-
{"source": "demo", "db_id": "demo", "query": "number of employees per city", "ok": true, "latency_ms": 5630, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 2602}, {"stage": "generator", "ms": 1097}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1018}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 906}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-125829/latency_per_stage.png
DELETED
|
Binary file (22.4 kB)
|
|
|
benchmarks/results_pro/20251108-125829/metrics_overview.png
DELETED
|
Binary file (12.9 kB)
|
|
|
benchmarks/results_pro/20251108-125829/results.csv
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
source,db_id,query,em,sm,exec_acc,ok,latency_ms
|
| 2 |
-
demo,demo,list all customers,,,,❌,6652
|
| 3 |
-
demo,demo,show total invoices per country,,,,✅,7375
|
| 4 |
-
demo,demo,top 3 albums by total sales,,,,✅,1
|
| 5 |
-
demo,demo,artists with more than 3 albums,,,,❌,8629
|
| 6 |
-
demo,demo,number of employees per city,,,,✅,5630
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251108-125829/summary.json
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"mode": "single-db",
|
| 3 |
-
"db_path": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/demo.db",
|
| 4 |
-
"config": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/configs/sqlite_pipeline.yaml",
|
| 5 |
-
"provider_hint": "REAL",
|
| 6 |
-
"total": 5,
|
| 7 |
-
"EM": 0.0,
|
| 8 |
-
"SM": 0.0,
|
| 9 |
-
"ExecAcc": 0.0,
|
| 10 |
-
"success_rate": 0.6,
|
| 11 |
-
"avg_latency_ms": 5657.4,
|
| 12 |
-
"timestamp": "2025-11-08 12:58:58"
|
| 13 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/results_pro/20251109-092540/eval.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 9423, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6884, "summary": "ok", "notes": {"len_plan": 1313}, "token_in": 270, "token_out": 313, "cost_usd": 0.0002283}, {"stage": "generator", "duration_ms": 891, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 801, "token_out": 19, "cost_usd": 0.00013155}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 673, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 962, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": 
null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 2 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 9382, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6936, "summary": "ok", "notes": {"len_plan": 1501}, "token_in": 271, "token_out": 351, "cost_usd": 0.00025124999999999995}, {"stage": "generator", "duration_ms": 1014, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 840, "token_out": 19, "cost_usd": 0.00013739999999999998}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 2, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 710, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 710, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, 
"summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 3 |
+
{"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}}]}
|
| 4 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 11380, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 7152, "summary": "ok", "notes": {"len_plan": 1281}, "token_in": 281, "token_out": 295, "cost_usd": 0.00021914999999999996}, {"stage": "generator", "duration_ms": 2189, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 794, "token_out": 37, "cost_usd": 0.0001413}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 954, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64}, "token_in": 325, "token_out": 21, "cost_usd": 6.135e-05}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1074, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64}, "token_in": 328, 
"token_out": 21, "cost_usd": 6.18e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 5 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 10894, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 7383, "summary": "ok", "notes": {"len_plan": 1579}, "token_in": 279, "token_out": 421, "cost_usd": 0.00029445}, {"stage": "generator", "duration_ms": 1242, "summary": "ok", "notes": {"rationale_len": 67}, "token_in": 918, "token_out": 42, "cost_usd": 0.00016289999999999998}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1078, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 333, "token_out": 24, "cost_usd": 6.435e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 3, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1173, 
"summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 337, "token_out": 28, "cost_usd": 6.735e-05}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
benchmarks/results_pro/20251109-092540/summary.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": "2025-11-09T09:26:21",
|
| 3 |
+
"total": 5,
|
| 4 |
+
"success": 5,
|
| 5 |
+
"success_rate": 1.0,
|
| 6 |
+
"avg_latency_ms": 8215.8,
|
| 7 |
+
"EM": 0.4,
|
| 8 |
+
"SM": 0.8,
|
| 9 |
+
"ExecAcc": 0.8,
|
| 10 |
+
"split": "dev",
|
| 11 |
+
"config": "configs/sqlite_pipeline.yaml"
|
| 12 |
+
}
|
benchmarks/results_pro/20251109-092823/eval.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 7982, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 5384, "summary": "ok", "notes": {"len_plan": 1287}, "token_in": 270, "token_out": 306, "cost_usd": 0.0002241}, {"stage": "generator", "duration_ms": 900, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 794, "token_out": 19, "cost_usd": 0.0001305}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 888, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 797, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": 
null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 2 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 9717, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6881, "summary": "ok", "notes": {"len_plan": 1352}, "token_in": 271, "token_out": 319, "cost_usd": 0.00023204999999999998}, {"stage": "generator", "duration_ms": 1162, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 808, "token_out": 19, "cost_usd": 0.0001326}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 716, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 950, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", 
"notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 3 |
+
{"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}}]}
|
| 4 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 8523, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 5311, "summary": "ok", "notes": {"len_plan": 1449}, "token_in": 281, "token_out": 343, "cost_usd": 0.00024795}, {"stage": "generator", "duration_ms": 1306, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 842, "token_out": 37, "cost_usd": 0.00014849999999999998}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 996, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64}, "token_in": 325, "token_out": 21, "cost_usd": 6.135e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 900, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64}, "token_in": 328, 
"token_out": 21, "cost_usd": 6.18e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 5 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 12291, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 8346, "summary": "ok", "notes": {"len_plan": 1363}, "token_in": 279, "token_out": 334, "cost_usd": 0.00024225}, {"stage": "generator", "duration_ms": 1636, "summary": "ok", "notes": {"rationale_len": 87}, "token_in": 831, "token_out": 46, "cost_usd": 0.00015225}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 2, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1137, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 333, "token_out": 25, "cost_usd": 6.495e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 3, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1151, "summary": "ok", 
"notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 337, "token_out": 28, "cost_usd": 6.735e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
benchmarks/results_pro/20251109-092823/summary.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": "2025-11-09T09:29:01",
|
| 3 |
+
"total": 5,
|
| 4 |
+
"success": 5,
|
| 5 |
+
"success_rate": 1.0,
|
| 6 |
+
"avg_latency_ms": 7702.6,
|
| 7 |
+
"EM": 0.4,
|
| 8 |
+
"SM": 0.8,
|
| 9 |
+
"ExecAcc": 0.8,
|
| 10 |
+
"split": "dev",
|
| 11 |
+
"config": "configs/sqlite_pipeline.yaml"
|
| 12 |
+
}
|
benchmarks/results_pro/20251109-093743/eval.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 10480, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 8010, "summary": "ok", "notes": {"len_plan": 1445}, "token_in": 270, "token_out": 337, "cost_usd": 0.00024270000000000002}, {"stage": "generator", "duration_ms": 1029, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 825, "token_out": 19, "cost_usd": 0.00013514999999999998}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 678, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 750, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", 
"notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 2 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 10687, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6978, "summary": "ok", "notes": {"len_plan": 1512}, "token_in": 271, "token_out": 355, "cost_usd": 0.00025364999999999996}, {"stage": "generator", "duration_ms": 2192, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 844, "token_out": 19, "cost_usd": 0.000138}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 652, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 863, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", 
"notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 3 |
+
{"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}}]}
|
| 4 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 16736, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 13205, "summary": "ok", "notes": {"len_plan": 1758}, "token_in": 281, "token_out": 409, "cost_usd": 0.00028754999999999997}, {"stage": "generator", "duration_ms": 1537, "summary": "ok", "notes": {"rationale_len": 83}, "token_in": 908, "token_out": 37, "cost_usd": 0.0001584}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1019, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64}, "token_in": 325, "token_out": 21, "cost_usd": 6.135e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 968, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64}, "token_in": 328, 
"token_out": 21, "cost_usd": 6.18e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
| 5 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 12440, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 7973, "summary": "ok", "notes": {"len_plan": 1377}, "token_in": 279, "token_out": 345, "cost_usd": 0.00024884999999999995}, {"stage": "generator", "duration_ms": 1827, "summary": "ok", "notes": {"rationale_len": 94}, "token_in": 841, "token_out": 47, "cost_usd": 0.00015434999999999998}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1312, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 333, "token_out": 24, "cost_usd": 6.435e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 2, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 
1313, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 337, "token_out": 21, "cost_usd": 6.315e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
|
benchmarks/results_pro/20251109-093743/summary.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": "2025-11-09T09:38:33",
|
| 3 |
+
"total": 5,
|
| 4 |
+
"success": 5,
|
| 5 |
+
"success_rate": 1.0,
|
| 6 |
+
"avg_latency_ms": 10068.6,
|
| 7 |
+
"EM": 0.4,
|
| 8 |
+
"SM": 0.8,
|
| 9 |
+
"ExecAcc": 0.8,
|
| 10 |
+
"split": "dev",
|
| 11 |
+
"config": "configs/sqlite_pipeline.yaml"
|
| 12 |
+
}
|
nl2sql/pipeline.py
CHANGED
|
@@ -31,9 +31,8 @@ class FinalResult:
|
|
| 31 |
|
| 32 |
class Pipeline:
|
| 33 |
"""
|
| 34 |
-
NL2SQL Copilot pipeline
|
| 35 |
-
|
| 36 |
-
DI-ready: all dependencies are injected via __init__.
|
| 37 |
"""
|
| 38 |
|
| 39 |
def __init__(
|
|
@@ -54,22 +53,21 @@ class Pipeline:
|
|
| 54 |
self.executor = executor or NoOpExecutor()
|
| 55 |
self.verifier = verifier or NoOpVerifier()
|
| 56 |
self.repair = repair or NoOpRepair()
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
#
|
| 59 |
@staticmethod
|
| 60 |
def _trace_list(*stages: Optional[StageResult]) -> List[dict]:
|
| 61 |
-
"""Collect .trace objects (as dict) from StageResult items if present."""
|
| 62 |
traces: List[dict] = []
|
| 63 |
for s in stages:
|
| 64 |
if not s:
|
| 65 |
continue
|
| 66 |
t = getattr(s, "trace", None)
|
| 67 |
if t is not None:
|
| 68 |
-
# t is likely a dataclass – expose as plain dict for JSON safety
|
| 69 |
traces.append(getattr(t, "__dict__", t))
|
| 70 |
return traces
|
| 71 |
|
| 72 |
-
# ------------------------------------------------------------
|
| 73 |
@staticmethod
|
| 74 |
def _mk_trace(
|
| 75 |
stage: str,
|
|
@@ -77,7 +75,6 @@ class Pipeline:
|
|
| 77 |
summary: str,
|
| 78 |
notes: Optional[Dict[str, Any]] = None,
|
| 79 |
) -> dict:
|
| 80 |
-
"""Create a normalized trace dict (internal: duration may be float)."""
|
| 81 |
return {
|
| 82 |
"stage": stage,
|
| 83 |
"duration_ms": float(duration_ms),
|
|
@@ -87,11 +84,6 @@ class Pipeline:
|
|
| 87 |
|
| 88 |
@staticmethod
|
| 89 |
def _normalize_traces(traces: List[dict]) -> List[dict]:
|
| 90 |
-
"""
|
| 91 |
-
Normalize trace list for API/UI:
|
| 92 |
-
- coerce duration_ms to int
|
| 93 |
-
- ensure `summary` exists (fallback to a minimal one)
|
| 94 |
-
"""
|
| 95 |
norm: List[dict] = []
|
| 96 |
for t in traces:
|
| 97 |
stage = str(t.get("stage", "unknown"))
|
|
@@ -100,37 +92,24 @@ class Pipeline:
|
|
| 100 |
dur_int = int(round(float(dur)))
|
| 101 |
except Exception:
|
| 102 |
dur_int = 0
|
| 103 |
-
summary = t.get("summary")
|
| 104 |
-
if not summary:
|
| 105 |
-
# fallback summary if not provided by stage
|
| 106 |
-
notes = t.get("notes") or {}
|
| 107 |
-
failed = bool(notes.get("error") or notes.get("errors"))
|
| 108 |
-
summary = "failed" if failed else "ok"
|
| 109 |
notes = t.get("notes") or {}
|
| 110 |
-
|
|
|
|
|
|
|
| 111 |
payload = {
|
| 112 |
"stage": stage,
|
| 113 |
"duration_ms": dur_int,
|
| 114 |
"summary": summary,
|
| 115 |
"notes": notes,
|
| 116 |
}
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
if "token_out" in t:
|
| 121 |
-
payload["token_out"] = t["token_out"]
|
| 122 |
-
if "cost_usd" in t:
|
| 123 |
-
payload["cost_usd"] = t["cost_usd"]
|
| 124 |
norm.append(payload)
|
| 125 |
return norm
|
| 126 |
|
| 127 |
-
# ------------------------------------------------------------
|
| 128 |
@staticmethod
|
| 129 |
def _safe_stage(fn, **kwargs) -> StageResult:
|
| 130 |
-
"""
|
| 131 |
-
Run a stage safely; if it throws, return a StageResult(ok=False, error=[...]).
|
| 132 |
-
If fn returns a non-StageResult (e.g., dict), coerce to StageResult(ok=True, data=...).
|
| 133 |
-
"""
|
| 134 |
try:
|
| 135 |
r = fn(**kwargs)
|
| 136 |
if isinstance(r, StageResult):
|
|
@@ -140,7 +119,7 @@ class Pipeline:
|
|
| 140 |
tb = traceback.format_exc()
|
| 141 |
return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
|
| 142 |
|
| 143 |
-
#
|
| 144 |
def run(
|
| 145 |
self,
|
| 146 |
*,
|
|
@@ -152,7 +131,6 @@ class Pipeline:
|
|
| 152 |
traces: List[dict] = []
|
| 153 |
details: List[str] = []
|
| 154 |
|
| 155 |
-
# Always push a normalized per-stage timing, even if StageResult.trace is empty
|
| 156 |
def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
|
| 157 |
traces.append(
|
| 158 |
self._mk_trace(
|
|
@@ -162,26 +140,24 @@ class Pipeline:
|
|
| 162 |
)
|
| 163 |
)
|
| 164 |
|
| 165 |
-
# Normalize inputs
|
| 166 |
schema_preview = schema_preview or ""
|
| 167 |
clarify_answers = clarify_answers or {}
|
| 168 |
|
| 169 |
try:
|
| 170 |
# --- 1) detector ---
|
| 171 |
-
|
| 172 |
questions = self.detector.detect(user_query, schema_preview)
|
| 173 |
-
|
| 174 |
is_amb = bool(questions)
|
| 175 |
-
stage_duration_ms.labels("detector").observe(
|
| 176 |
traces.append(
|
| 177 |
self._mk_trace(
|
| 178 |
stage="detector",
|
| 179 |
-
duration_ms=
|
| 180 |
summary=("ambiguous" if is_amb else "clear"),
|
| 181 |
notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
|
| 182 |
)
|
| 183 |
)
|
| 184 |
-
|
| 185 |
if questions:
|
| 186 |
pipeline_runs_total.labels(status="ambiguous").inc()
|
| 187 |
return FinalResult(
|
|
@@ -197,15 +173,15 @@ class Pipeline:
|
|
| 197 |
)
|
| 198 |
|
| 199 |
# --- 2) planner ---
|
| 200 |
-
|
| 201 |
r_plan = self._safe_stage(
|
| 202 |
self.planner.run, user_query=user_query, schema_preview=schema_preview
|
| 203 |
)
|
| 204 |
-
|
| 205 |
-
stage_duration_ms.labels("planner").observe(
|
| 206 |
traces.extend(self._trace_list(r_plan))
|
| 207 |
if not getattr(r_plan, "trace", None):
|
| 208 |
-
_fallback_trace("planner",
|
| 209 |
if not r_plan.ok:
|
| 210 |
pipeline_runs_total.labels(status="error").inc()
|
| 211 |
return FinalResult(
|
|
@@ -221,7 +197,7 @@ class Pipeline:
|
|
| 221 |
)
|
| 222 |
|
| 223 |
# --- 3) generator ---
|
| 224 |
-
|
| 225 |
r_gen = self._safe_stage(
|
| 226 |
self.generator.run,
|
| 227 |
user_query=user_query,
|
|
@@ -229,11 +205,11 @@ class Pipeline:
|
|
| 229 |
plan_text=(r_plan.data or {}).get("plan"),
|
| 230 |
clarify_answers=clarify_answers,
|
| 231 |
)
|
| 232 |
-
|
| 233 |
-
stage_duration_ms.labels("generator").observe(
|
| 234 |
traces.extend(self._trace_list(r_gen))
|
| 235 |
if not getattr(r_gen, "trace", None):
|
| 236 |
-
_fallback_trace("generator",
|
| 237 |
if not r_gen.ok:
|
| 238 |
pipeline_runs_total.labels(status="error").inc()
|
| 239 |
return FinalResult(
|
|
@@ -251,14 +227,32 @@ class Pipeline:
|
|
| 251 |
sql = (r_gen.data or {}).get("sql")
|
| 252 |
rationale = (r_gen.data or {}).get("rationale")
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
# --- 4) safety ---
|
| 255 |
-
|
| 256 |
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 257 |
-
|
| 258 |
-
stage_duration_ms.labels("safety").observe(
|
| 259 |
traces.extend(self._trace_list(r_safe))
|
| 260 |
if not getattr(r_safe, "trace", None):
|
| 261 |
-
_fallback_trace("safety",
|
| 262 |
if not r_safe.ok:
|
| 263 |
pipeline_runs_total.labels(status="error").inc()
|
| 264 |
return FinalResult(
|
|
@@ -273,99 +267,112 @@ class Pipeline:
|
|
| 273 |
traces=self._normalize_traces(traces),
|
| 274 |
)
|
| 275 |
|
|
|
|
|
|
|
|
|
|
| 276 |
# --- 5) executor ---
|
| 277 |
-
|
| 278 |
-
r_exec = self._safe_stage(
|
| 279 |
-
|
| 280 |
-
)
|
| 281 |
-
exe_ms = (time.perf_counter() - t_exe0) * 1000.0
|
| 282 |
-
stage_duration_ms.labels("executor").observe(exe_ms)
|
| 283 |
traces.extend(self._trace_list(r_exec))
|
| 284 |
if not getattr(r_exec, "trace", None):
|
| 285 |
-
_fallback_trace("executor",
|
| 286 |
if not r_exec.ok and r_exec.error:
|
| 287 |
-
#
|
| 288 |
-
details.extend(r_exec.error)
|
| 289 |
|
| 290 |
# --- 6) verifier ---
|
| 291 |
-
|
| 292 |
r_ver = self._safe_stage(
|
| 293 |
-
self.verifier.run,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
| 295 |
-
|
| 296 |
-
stage_duration_ms.labels("verifier").observe(
|
| 297 |
traces.extend(self._trace_list(r_ver))
|
| 298 |
if not getattr(r_ver, "trace", None):
|
| 299 |
-
_fallback_trace("verifier",
|
| 300 |
verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
|
| 301 |
|
| 302 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
if not verified:
|
| 304 |
for _attempt in range(2):
|
| 305 |
# repair
|
| 306 |
-
|
| 307 |
r_fix = self._safe_stage(
|
| 308 |
self.repair.run,
|
| 309 |
sql=sql,
|
| 310 |
error_msg="; ".join(details or ["unknown"]),
|
| 311 |
schema_preview=schema_preview,
|
| 312 |
)
|
| 313 |
-
|
| 314 |
-
stage_duration_ms.labels("repair").observe(
|
| 315 |
traces.extend(self._trace_list(r_fix))
|
| 316 |
if not getattr(r_fix, "trace", None):
|
| 317 |
-
_fallback_trace("repair",
|
| 318 |
if not r_fix.ok:
|
| 319 |
-
break
|
| 320 |
|
| 321 |
-
#
|
| 322 |
sql = (r_fix.data or {}).get("sql", sql)
|
| 323 |
|
| 324 |
-
# safety
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
stage_duration_ms.labels("safety").observe(
|
| 329 |
-
traces.extend(self._trace_list(
|
| 330 |
-
if not getattr(
|
| 331 |
-
_fallback_trace("safety",
|
| 332 |
-
if not
|
| 333 |
-
if
|
| 334 |
-
details.extend(
|
| 335 |
continue
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
)
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
details.extend(r_exec.error)
|
| 350 |
continue
|
| 351 |
|
| 352 |
-
# verifier
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
self.verifier.run,
|
|
|
|
|
|
|
|
|
|
| 356 |
)
|
| 357 |
-
|
| 358 |
-
stage_duration_ms.labels("verifier").observe(
|
| 359 |
-
traces.extend(self._trace_list(
|
| 360 |
-
if not getattr(
|
| 361 |
-
_fallback_trace("verifier",
|
| 362 |
verified = (
|
| 363 |
-
bool(
|
| 364 |
)
|
|
|
|
|
|
|
| 365 |
if verified:
|
| 366 |
break
|
| 367 |
|
| 368 |
-
# --- 8)
|
| 369 |
if (verified is None or not verified) and not details:
|
| 370 |
any_exec_ok = any(
|
| 371 |
t.get("stage") == "executor"
|
|
@@ -385,13 +392,24 @@ class Pipeline:
|
|
| 385 |
|
| 386 |
# --- 9) finalize ---
|
| 387 |
has_errors = bool(details)
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
-
|
| 392 |
-
|
|
|
|
|
|
|
| 393 |
else:
|
| 394 |
-
|
|
|
|
|
|
|
| 395 |
|
| 396 |
traces.append(
|
| 397 |
self._mk_trace(
|
|
@@ -399,8 +417,9 @@ class Pipeline:
|
|
| 399 |
duration_ms=0.0,
|
| 400 |
summary="finalize",
|
| 401 |
notes={
|
| 402 |
-
"final_verified": bool(
|
| 403 |
"details_len": len(details),
|
|
|
|
| 404 |
},
|
| 405 |
)
|
| 406 |
)
|
|
@@ -412,18 +431,18 @@ class Pipeline:
|
|
| 412 |
details=details or None,
|
| 413 |
sql=sql,
|
| 414 |
rationale=rationale,
|
| 415 |
-
verified=
|
| 416 |
questions=None,
|
| 417 |
traces=self._normalize_traces(traces),
|
| 418 |
)
|
| 419 |
|
| 420 |
except Exception:
|
| 421 |
-
# Any unexpected crash
|
| 422 |
pipeline_runs_total.labels(status="error").inc()
|
|
|
|
| 423 |
raise
|
| 424 |
|
| 425 |
finally:
|
| 426 |
-
# Always record total latency even on early
|
| 427 |
stage_duration_ms.labels("pipeline_total").observe(
|
| 428 |
(time.perf_counter() - t_all0) * 1000.0
|
| 429 |
)
|
|
|
|
| 31 |
|
| 32 |
class Pipeline:
|
| 33 |
"""
|
| 34 |
+
NL2SQL Copilot pipeline:
|
| 35 |
+
detector → planner → generator → safety → executor → verifier → (optional repair loop).
|
|
|
|
| 36 |
"""
|
| 37 |
|
| 38 |
def __init__(
|
|
|
|
| 53 |
self.executor = executor or NoOpExecutor()
|
| 54 |
self.verifier = verifier or NoOpVerifier()
|
| 55 |
self.repair = repair or NoOpRepair()
|
| 56 |
+
# If the verifier explicitly requires verification, enforce it in finalize.
|
| 57 |
+
self.require_verification = bool(getattr(self.verifier, "required", False))
|
| 58 |
|
| 59 |
+
# ---------------------------- helpers ----------------------------
|
| 60 |
@staticmethod
|
| 61 |
def _trace_list(*stages: Optional[StageResult]) -> List[dict]:
|
|
|
|
| 62 |
traces: List[dict] = []
|
| 63 |
for s in stages:
|
| 64 |
if not s:
|
| 65 |
continue
|
| 66 |
t = getattr(s, "trace", None)
|
| 67 |
if t is not None:
|
|
|
|
| 68 |
traces.append(getattr(t, "__dict__", t))
|
| 69 |
return traces
|
| 70 |
|
|
|
|
| 71 |
@staticmethod
|
| 72 |
def _mk_trace(
|
| 73 |
stage: str,
|
|
|
|
| 75 |
summary: str,
|
| 76 |
notes: Optional[Dict[str, Any]] = None,
|
| 77 |
) -> dict:
|
|
|
|
| 78 |
return {
|
| 79 |
"stage": stage,
|
| 80 |
"duration_ms": float(duration_ms),
|
|
|
|
| 84 |
|
| 85 |
@staticmethod
|
| 86 |
def _normalize_traces(traces: List[dict]) -> List[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
norm: List[dict] = []
|
| 88 |
for t in traces:
|
| 89 |
stage = str(t.get("stage", "unknown"))
|
|
|
|
| 92 |
dur_int = int(round(float(dur)))
|
| 93 |
except Exception:
|
| 94 |
dur_int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
notes = t.get("notes") or {}
|
| 96 |
+
summary = t.get("summary") or (
|
| 97 |
+
"failed" if (notes.get("error") or notes.get("errors")) else "ok"
|
| 98 |
+
)
|
| 99 |
payload = {
|
| 100 |
"stage": stage,
|
| 101 |
"duration_ms": dur_int,
|
| 102 |
"summary": summary,
|
| 103 |
"notes": notes,
|
| 104 |
}
|
| 105 |
+
for k in ("token_in", "token_out", "cost_usd"):
|
| 106 |
+
if k in t:
|
| 107 |
+
payload[k] = t[k]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
norm.append(payload)
|
| 109 |
return norm
|
| 110 |
|
|
|
|
| 111 |
@staticmethod
|
| 112 |
def _safe_stage(fn, **kwargs) -> StageResult:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
try:
|
| 114 |
r = fn(**kwargs)
|
| 115 |
if isinstance(r, StageResult):
|
|
|
|
| 119 |
tb = traceback.format_exc()
|
| 120 |
return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
|
| 121 |
|
| 122 |
+
# ------------------------------ run ------------------------------
|
| 123 |
def run(
|
| 124 |
self,
|
| 125 |
*,
|
|
|
|
| 131 |
traces: List[dict] = []
|
| 132 |
details: List[str] = []
|
| 133 |
|
|
|
|
| 134 |
def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
|
| 135 |
traces.append(
|
| 136 |
self._mk_trace(
|
|
|
|
| 140 |
)
|
| 141 |
)
|
| 142 |
|
|
|
|
| 143 |
schema_preview = schema_preview or ""
|
| 144 |
clarify_answers = clarify_answers or {}
|
| 145 |
|
| 146 |
try:
|
| 147 |
# --- 1) detector ---
|
| 148 |
+
t0 = time.perf_counter()
|
| 149 |
questions = self.detector.detect(user_query, schema_preview)
|
| 150 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 151 |
is_amb = bool(questions)
|
| 152 |
+
stage_duration_ms.labels("detector").observe(dt)
|
| 153 |
traces.append(
|
| 154 |
self._mk_trace(
|
| 155 |
stage="detector",
|
| 156 |
+
duration_ms=dt,
|
| 157 |
summary=("ambiguous" if is_amb else "clear"),
|
| 158 |
notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
|
| 159 |
)
|
| 160 |
)
|
|
|
|
| 161 |
if questions:
|
| 162 |
pipeline_runs_total.labels(status="ambiguous").inc()
|
| 163 |
return FinalResult(
|
|
|
|
| 173 |
)
|
| 174 |
|
| 175 |
# --- 2) planner ---
|
| 176 |
+
t0 = time.perf_counter()
|
| 177 |
r_plan = self._safe_stage(
|
| 178 |
self.planner.run, user_query=user_query, schema_preview=schema_preview
|
| 179 |
)
|
| 180 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 181 |
+
stage_duration_ms.labels("planner").observe(dt)
|
| 182 |
traces.extend(self._trace_list(r_plan))
|
| 183 |
if not getattr(r_plan, "trace", None):
|
| 184 |
+
_fallback_trace("planner", dt, r_plan.ok)
|
| 185 |
if not r_plan.ok:
|
| 186 |
pipeline_runs_total.labels(status="error").inc()
|
| 187 |
return FinalResult(
|
|
|
|
| 197 |
)
|
| 198 |
|
| 199 |
# --- 3) generator ---
|
| 200 |
+
t0 = time.perf_counter()
|
| 201 |
r_gen = self._safe_stage(
|
| 202 |
self.generator.run,
|
| 203 |
user_query=user_query,
|
|
|
|
| 205 |
plan_text=(r_plan.data or {}).get("plan"),
|
| 206 |
clarify_answers=clarify_answers,
|
| 207 |
)
|
| 208 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 209 |
+
stage_duration_ms.labels("generator").observe(dt)
|
| 210 |
traces.extend(self._trace_list(r_gen))
|
| 211 |
if not getattr(r_gen, "trace", None):
|
| 212 |
+
_fallback_trace("generator", dt, r_gen.ok)
|
| 213 |
if not r_gen.ok:
|
| 214 |
pipeline_runs_total.labels(status="error").inc()
|
| 215 |
return FinalResult(
|
|
|
|
| 227 |
sql = (r_gen.data or {}).get("sql")
|
| 228 |
rationale = (r_gen.data or {}).get("rationale")
|
| 229 |
|
| 230 |
+
# Guard: empty SQL
|
| 231 |
+
if not sql or not str(sql).strip():
|
| 232 |
+
pipeline_runs_total.labels(status="error").inc()
|
| 233 |
+
traces.append(
|
| 234 |
+
self._mk_trace("generator", 0.0, "failed", {"reason": "empty_sql"})
|
| 235 |
+
)
|
| 236 |
+
return FinalResult(
|
| 237 |
+
ok=False,
|
| 238 |
+
ambiguous=False,
|
| 239 |
+
error=True,
|
| 240 |
+
details=["empty_sql"],
|
| 241 |
+
questions=None,
|
| 242 |
+
sql=None,
|
| 243 |
+
rationale=rationale,
|
| 244 |
+
verified=None,
|
| 245 |
+
traces=self._normalize_traces(traces),
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
# --- 4) safety ---
|
| 249 |
+
t0 = time.perf_counter()
|
| 250 |
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 251 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 252 |
+
stage_duration_ms.labels("safety").observe(dt)
|
| 253 |
traces.extend(self._trace_list(r_safe))
|
| 254 |
if not getattr(r_safe, "trace", None):
|
| 255 |
+
_fallback_trace("safety", dt, r_safe.ok)
|
| 256 |
if not r_safe.ok:
|
| 257 |
pipeline_runs_total.labels(status="error").inc()
|
| 258 |
return FinalResult(
|
|
|
|
| 267 |
traces=self._normalize_traces(traces),
|
| 268 |
)
|
| 269 |
|
| 270 |
+
# Use sanitized SQL from safety
|
| 271 |
+
sql = (r_safe.data or {}).get("sql", sql)
|
| 272 |
+
|
| 273 |
# --- 5) executor ---
|
| 274 |
+
t0 = time.perf_counter()
|
| 275 |
+
r_exec = self._safe_stage(self.executor.run, sql=sql)
|
| 276 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 277 |
+
stage_duration_ms.labels("executor").observe(dt)
|
|
|
|
|
|
|
| 278 |
traces.extend(self._trace_list(r_exec))
|
| 279 |
if not getattr(r_exec, "trace", None):
|
| 280 |
+
_fallback_trace("executor", dt, r_exec.ok)
|
| 281 |
if not r_exec.ok and r_exec.error:
|
| 282 |
+
details.extend(r_exec.error) # soft: keep for repair/verifier context
|
|
|
|
| 283 |
|
| 284 |
# --- 6) verifier ---
|
| 285 |
+
t0 = time.perf_counter()
|
| 286 |
r_ver = self._safe_stage(
|
| 287 |
+
self.verifier.run,
|
| 288 |
+
sql=sql,
|
| 289 |
+
exec_result=(r_exec.data or {}),
|
| 290 |
+
adapter=getattr(
|
| 291 |
+
self.executor, "adapter", None
|
| 292 |
+
), # let verifier use adapter
|
| 293 |
)
|
| 294 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 295 |
+
stage_duration_ms.labels("verifier").observe(dt)
|
| 296 |
traces.extend(self._trace_list(r_ver))
|
| 297 |
if not getattr(r_ver, "trace", None):
|
| 298 |
+
_fallback_trace("verifier", dt, r_ver.ok)
|
| 299 |
verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
|
| 300 |
|
| 301 |
+
# consume repaired SQL from verifier if any
|
| 302 |
+
if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
|
| 303 |
+
sql = r_ver.data["sql"]
|
| 304 |
+
|
| 305 |
+
# --- 7) repair loop (if not verified) ---
|
| 306 |
if not verified:
|
| 307 |
for _attempt in range(2):
|
| 308 |
# repair
|
| 309 |
+
t0 = time.perf_counter()
|
| 310 |
r_fix = self._safe_stage(
|
| 311 |
self.repair.run,
|
| 312 |
sql=sql,
|
| 313 |
error_msg="; ".join(details or ["unknown"]),
|
| 314 |
schema_preview=schema_preview,
|
| 315 |
)
|
| 316 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 317 |
+
stage_duration_ms.labels("repair").observe(dt)
|
| 318 |
traces.extend(self._trace_list(r_fix))
|
| 319 |
if not getattr(r_fix, "trace", None):
|
| 320 |
+
_fallback_trace("repair", dt, r_fix.ok)
|
| 321 |
if not r_fix.ok:
|
| 322 |
+
break
|
| 323 |
|
| 324 |
+
# update SQL
|
| 325 |
sql = (r_fix.data or {}).get("sql", sql)
|
| 326 |
|
| 327 |
+
# safety again
|
| 328 |
+
t0 = time.perf_counter()
|
| 329 |
+
r_safe2 = self._safe_stage(self.safety.run, sql=sql)
|
| 330 |
+
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 331 |
+
stage_duration_ms.labels("safety").observe(dt2)
|
| 332 |
+
traces.extend(self._trace_list(r_safe2))
|
| 333 |
+
if not getattr(r_safe2, "trace", None):
|
| 334 |
+
_fallback_trace("safety", dt2, r_safe2.ok)
|
| 335 |
+
if not r_safe2.ok:
|
| 336 |
+
if r_safe2.error:
|
| 337 |
+
details.extend(r_safe2.error)
|
| 338 |
continue
|
| 339 |
+
sql = (r_safe2.data or {}).get("sql", sql)
|
| 340 |
+
|
| 341 |
+
# executor again
|
| 342 |
+
t0 = time.perf_counter()
|
| 343 |
+
r_exec2 = self._safe_stage(self.executor.run, sql=sql)
|
| 344 |
+
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 345 |
+
stage_duration_ms.labels("executor").observe(dt2)
|
| 346 |
+
traces.extend(self._trace_list(r_exec2))
|
| 347 |
+
if not getattr(r_exec2, "trace", None):
|
| 348 |
+
_fallback_trace("executor", dt2, r_exec2.ok)
|
| 349 |
+
if not r_exec2.ok:
|
| 350 |
+
if r_exec2.error:
|
| 351 |
+
details.extend(r_exec2.error)
|
|
|
|
| 352 |
continue
|
| 353 |
|
| 354 |
+
# verifier again
|
| 355 |
+
t0 = time.perf_counter()
|
| 356 |
+
r_ver2 = self._safe_stage(
|
| 357 |
+
self.verifier.run,
|
| 358 |
+
sql=sql,
|
| 359 |
+
exec_result=(r_exec2.data or {}),
|
| 360 |
+
adapter=getattr(self.executor, "adapter", None),
|
| 361 |
)
|
| 362 |
+
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 363 |
+
stage_duration_ms.labels("verifier").observe(dt2)
|
| 364 |
+
traces.extend(self._trace_list(r_ver2))
|
| 365 |
+
if not getattr(r_ver2, "trace", None):
|
| 366 |
+
_fallback_trace("verifier", dt2, r_ver2.ok)
|
| 367 |
verified = (
|
| 368 |
+
bool(r_ver2.data and r_ver2.data.get("verified")) or r_ver2.ok
|
| 369 |
)
|
| 370 |
+
if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
|
| 371 |
+
sql = r_ver2.data["sql"]
|
| 372 |
if verified:
|
| 373 |
break
|
| 374 |
|
| 375 |
+
# --- 8) optional soft auto-verify (executor success, no details) ---
|
| 376 |
if (verified is None or not verified) and not details:
|
| 377 |
any_exec_ok = any(
|
| 378 |
t.get("stage") == "executor"
|
|
|
|
| 392 |
|
| 393 |
# --- 9) finalize ---
|
| 394 |
has_errors = bool(details)
|
| 395 |
+
need_ver = bool(self.require_verification)
|
| 396 |
+
|
| 397 |
+
# base success condition
|
| 398 |
+
final_ok_by_verifier = bool(verified)
|
| 399 |
+
base_ok = (
|
| 400 |
+
bool(sql) and not has_errors and (final_ok_by_verifier or not need_ver)
|
| 401 |
+
)
|
| 402 |
+
ok = base_ok
|
| 403 |
+
err = (not ok) and has_errors
|
| 404 |
|
| 405 |
+
# align `verified` with baseline semantics:
|
| 406 |
+
# if verification is NOT required and pipeline is ok, report verified=True
|
| 407 |
+
if not need_ver and ok and not final_ok_by_verifier:
|
| 408 |
+
verified_final = True
|
| 409 |
else:
|
| 410 |
+
verified_final = bool(verified)
|
| 411 |
+
|
| 412 |
+
pipeline_runs_total.labels(status=("ok" if ok else "error")).inc()
|
| 413 |
|
| 414 |
traces.append(
|
| 415 |
self._mk_trace(
|
|
|
|
| 417 |
duration_ms=0.0,
|
| 418 |
summary="finalize",
|
| 419 |
notes={
|
| 420 |
+
"final_verified": bool(verified_final),
|
| 421 |
"details_len": len(details),
|
| 422 |
+
"need_verification": need_ver,
|
| 423 |
},
|
| 424 |
)
|
| 425 |
)
|
|
|
|
| 431 |
details=details or None,
|
| 432 |
sql=sql,
|
| 433 |
rationale=rationale,
|
| 434 |
+
verified=verified_final,
|
| 435 |
questions=None,
|
| 436 |
traces=self._normalize_traces(traces),
|
| 437 |
)
|
| 438 |
|
| 439 |
except Exception:
|
|
|
|
| 440 |
pipeline_runs_total.labels(status="error").inc()
|
| 441 |
+
# bubble up to make failures visible in tests and logs
|
| 442 |
raise
|
| 443 |
|
| 444 |
finally:
|
| 445 |
+
# Always record total latency, even on early return/exception
|
| 446 |
stage_duration_ms.labels("pipeline_total").observe(
|
| 447 |
(time.perf_counter() - t_all0) * 1000.0
|
| 448 |
)
|
nl2sql/verifier.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
-
|
| 3 |
import re
|
| 4 |
import time
|
| 5 |
-
from typing import Any, Iterable, List, Optional
|
| 6 |
|
| 7 |
import sqlglot
|
| 8 |
from sqlglot import expressions as exp
|
|
@@ -10,24 +9,65 @@ from sqlglot import expressions as exp
|
|
| 10 |
from nl2sql.types import StageResult, StageTrace
|
| 11 |
from nl2sql.metrics import (
|
| 12 |
verifier_checks_total,
|
| 13 |
-
stage_duration_ms,
|
| 14 |
verifier_failures_total,
|
| 15 |
)
|
| 16 |
|
| 17 |
|
| 18 |
def _ms(t0: float) -> int:
|
|
|
|
| 19 |
return int((time.perf_counter() - t0) * 1000)
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
class Verifier:
|
| 23 |
name = "verifier"
|
| 24 |
|
| 25 |
-
#
|
| 26 |
_AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
|
| 27 |
|
| 28 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def _walk(self, node: exp.Expression) -> Iterable[exp.Expression]:
|
| 30 |
-
"""
|
| 31 |
stack = [node]
|
| 32 |
while stack:
|
| 33 |
cur = stack.pop()
|
|
@@ -43,6 +83,7 @@ class Verifier:
|
|
| 43 |
stack.append(it)
|
| 44 |
|
| 45 |
def _first_select(self, tree: exp.Expression) -> Optional[exp.Select]:
|
|
|
|
| 46 |
for n in self._walk(tree):
|
| 47 |
if isinstance(n, exp.Select):
|
| 48 |
return n
|
|
@@ -50,27 +91,22 @@ class Verifier:
|
|
| 50 |
|
| 51 |
def _has_group_by(self, tree: exp.Expression) -> bool:
|
| 52 |
sel = self._first_select(tree)
|
| 53 |
-
if
|
| 54 |
-
return False
|
| 55 |
-
# sqlglot stores GROUP BY on Select.group
|
| 56 |
-
return bool(getattr(sel, "group", None))
|
| 57 |
|
| 58 |
def _is_distinct_projection(self, tree: exp.Expression) -> bool:
|
| 59 |
sel = self._first_select(tree)
|
| 60 |
if not sel:
|
| 61 |
return False
|
| 62 |
-
# DISTINCT may appear as Select.distinct or a Distinct node
|
| 63 |
if getattr(sel, "distinct", None):
|
| 64 |
return True
|
| 65 |
return any(isinstance(n, exp.Distinct) for n in self._walk(sel))
|
| 66 |
|
| 67 |
def _has_windowed_aggregate(self, tree: exp.Expression) -> bool:
|
| 68 |
-
# If there is any OVER(...) window, aggregates without GROUP BY can be legitimate
|
| 69 |
return any(isinstance(n, exp.Window) for n in self._walk(tree))
|
| 70 |
|
| 71 |
def _expr_contains_agg(self, node: exp.Expression) -> bool:
|
| 72 |
-
"""True if
|
| 73 |
-
|
| 74 |
agg_type_names = (
|
| 75 |
"Count",
|
| 76 |
"Sum",
|
|
@@ -81,26 +117,24 @@ class Verifier:
|
|
| 81 |
"ArrayAgg",
|
| 82 |
"StringAgg",
|
| 83 |
)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
t
|
| 87 |
-
if isinstance(t, type)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
if AGG_TYPES and any(isinstance(n, AGG_TYPES) for n in self._walk(node)):
|
| 93 |
return True
|
| 94 |
|
| 95 |
-
#
|
| 96 |
Anonymous = getattr(exp, "Anonymous", None)
|
| 97 |
func_like = (exp.Func,) + ((Anonymous,) if isinstance(Anonymous, type) else ())
|
| 98 |
-
AGG_NAMES = {"count", "sum", "avg", "min", "max"}
|
| 99 |
|
| 100 |
-
def
|
| 101 |
-
|
| 102 |
-
if isinstance(
|
| 103 |
-
return
|
| 104 |
this = getattr(n, "this", None)
|
| 105 |
if isinstance(this, str):
|
| 106 |
return this.lower()
|
|
@@ -110,82 +144,138 @@ class Verifier:
|
|
| 110 |
return (str(this) or "").lower()
|
| 111 |
|
| 112 |
for n in self._walk(node):
|
| 113 |
-
if isinstance(n, func_like) and
|
| 114 |
-
return True
|
| 115 |
-
|
| 116 |
-
return False
|
| 117 |
-
|
| 118 |
-
def _has_nonagg_column(self, node: exp.Expression) -> bool:
|
| 119 |
-
"""Subtree contains a column reference that is NOT inside an aggregate."""
|
| 120 |
-
# Check if there are any columns in this expression
|
| 121 |
-
columns = [n for n in self._walk(node) if isinstance(n, exp.Column)]
|
| 122 |
-
if not columns:
|
| 123 |
-
return False
|
| 124 |
-
|
| 125 |
-
# Check if all columns are inside aggregates
|
| 126 |
-
for col in columns:
|
| 127 |
-
# Walk up from column to see if it's inside an aggregate
|
| 128 |
-
# is_in_agg = False
|
| 129 |
-
# For simplicity, check if the entire expression contains both column and aggregate
|
| 130 |
-
# A more precise check would require parent tracking
|
| 131 |
-
if self._expr_contains_agg(node):
|
| 132 |
-
# This is a simplified check - if the node has both columns and aggregates,
|
| 133 |
-
# we need more complex logic to determine if columns are outside aggregates
|
| 134 |
-
return True
|
| 135 |
-
else:
|
| 136 |
-
# No aggregates, so if there are columns, they're non-aggregate
|
| 137 |
return True
|
| 138 |
return False
|
| 139 |
|
| 140 |
-
# ----------------------- Textual fallback helpers -------------------------
|
| 141 |
def _clean_sql_for_fn_scan(self, sql: str) -> str:
|
| 142 |
-
"""
|
| 143 |
s = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) # block comments
|
| 144 |
s = re.sub(r"--.*?$", " ", s, flags=re.MULTILINE) # line comments
|
| 145 |
s = re.sub(
|
| 146 |
r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`)", " ", s
|
| 147 |
-
) # quoted strings
|
| 148 |
s = re.sub(r"\s+", " ", s).strip()
|
| 149 |
return s
|
| 150 |
|
| 151 |
-
#
|
| 152 |
-
def
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
return None
|
| 158 |
|
| 159 |
-
def
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
t0 = time.perf_counter()
|
| 169 |
issues: List[str] = []
|
|
|
|
| 170 |
|
| 171 |
-
#
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
-
|
|
|
|
|
|
|
| 176 |
if tree is None:
|
| 177 |
return StageResult(
|
| 178 |
ok=False,
|
| 179 |
error=["parse_error"],
|
| 180 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 181 |
)
|
| 182 |
-
|
| 183 |
-
# sqlglot may parse broken SQL as an "Unknown" or "Command" type
|
| 184 |
-
# Check if we got a proper SQL statement type
|
| 185 |
tree_type = type(tree).__name__
|
| 186 |
-
|
| 187 |
-
# Check for common sqlglot error indicators
|
| 188 |
-
# When sqlglot can't parse properly, it often creates Command or Unknown nodes
|
| 189 |
if tree_type in ("Command", "Unknown"):
|
| 190 |
verifier_checks_total.labels(ok="false").inc()
|
| 191 |
verifier_failures_total.labels(reason="parse_error").inc()
|
|
@@ -194,36 +284,6 @@ class Verifier:
|
|
| 194 |
error=["parse_error"],
|
| 195 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 196 |
)
|
| 197 |
-
|
| 198 |
-
# Also check if the tree has errors attribute (some versions of sqlglot)
|
| 199 |
-
if hasattr(tree, "errors") and tree.errors:
|
| 200 |
-
verifier_checks_total.labels(ok="false").inc()
|
| 201 |
-
verifier_failures_total.labels(reason="parse_error").inc()
|
| 202 |
-
return StageResult(
|
| 203 |
-
ok=False,
|
| 204 |
-
error=["parse_error"],
|
| 205 |
-
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
-
# Additional check: if it's not a recognized DML/DQL statement
|
| 209 |
-
valid_types = ("Select", "With", "Union", "Intersect", "Except", "Values")
|
| 210 |
-
if tree_type not in valid_types:
|
| 211 |
-
# This might be a parse error disguised as a different statement type
|
| 212 |
-
# Let's check if it looks like it should be a SELECT
|
| 213 |
-
sql_lower = sql.lower().strip()
|
| 214 |
-
if any(
|
| 215 |
-
sql_lower.startswith(kw)
|
| 216 |
-
for kw in ["selct", "slect", "selet", "seelct"]
|
| 217 |
-
):
|
| 218 |
-
# Common misspellings of SELECT
|
| 219 |
-
verifier_checks_total.labels(ok="false").inc()
|
| 220 |
-
verifier_failures_total.labels(reason="parse_error").inc()
|
| 221 |
-
return StageResult(
|
| 222 |
-
ok=False,
|
| 223 |
-
error=["parse_error"],
|
| 224 |
-
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
except Exception:
|
| 228 |
verifier_checks_total.labels(ok="false").inc()
|
| 229 |
verifier_failures_total.labels(reason="parse_error").inc()
|
|
@@ -233,29 +293,22 @@ class Verifier:
|
|
| 233 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 234 |
)
|
| 235 |
|
| 236 |
-
# 2) Semantic
|
| 237 |
try:
|
| 238 |
sel = self._first_select(tree)
|
| 239 |
if sel:
|
| 240 |
has_group = self._has_group_by(tree)
|
| 241 |
has_window = self._has_windowed_aggregate(tree)
|
| 242 |
is_distinct = self._is_distinct_projection(tree)
|
| 243 |
-
|
| 244 |
select_items = list(getattr(sel, "expressions", []) or [])
|
| 245 |
any_agg = any(self._expr_contains_agg(it) for it in select_items)
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
if has_cols and not has_aggs:
|
| 254 |
-
any_nonagg_col = True
|
| 255 |
-
break
|
| 256 |
-
|
| 257 |
-
# Core rule: aggregate + non-aggregate column without GROUP BY is an issue,
|
| 258 |
-
# unless DISTINCT or windowed aggregate makes it legitimate.
|
| 259 |
if (
|
| 260 |
any_agg
|
| 261 |
and any_nonagg_col
|
|
@@ -264,72 +317,111 @@ class Verifier:
|
|
| 264 |
verifier_failures_total.labels(reason="semantic_error").inc()
|
| 265 |
issues.append("aggregation_without_group_by")
|
| 266 |
except Exception as e:
|
| 267 |
-
# Don't crash the verifier; surface a soft issue and let fallback run
|
| 268 |
verifier_failures_total.labels(reason="semantic_error").inc()
|
| 269 |
issues.append(f"semantic_check_error:{e!s}")
|
| 270 |
-
|
| 271 |
-
#
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
r"\bselect\s+distinct\b", cleaned, re.IGNORECASE
|
| 280 |
)
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
verifier_failures_total.labels(
|
| 295 |
-
reason="
|
| 296 |
).inc()
|
| 297 |
issues.append("aggregation_without_group_by")
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
-
# 4)
|
|
|
|
|
|
|
| 303 |
try:
|
| 304 |
-
exec_result
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
verifier_failures_total.labels(reason="preview_exec_error").inc()
|
| 309 |
-
issues.append(f"exec_error:{
|
| 310 |
except Exception as e:
|
| 311 |
verifier_failures_total.labels(reason="preview_exec_error").inc()
|
| 312 |
issues.append(f"exec_exception:{e!s}")
|
| 313 |
|
| 314 |
-
# 5) Final
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
return StageResult(
|
| 319 |
ok=False,
|
| 320 |
-
error=issues,
|
| 321 |
trace=StageTrace(
|
| 322 |
stage=self.name, duration_ms=_ms(t0), notes={"issues": issues}
|
| 323 |
),
|
| 324 |
)
|
| 325 |
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
def run(self, *, sql: str, adapter: Any) -> StageResult:
|
| 335 |
-
return self.verify(sql, adapter=adapter)
|
|
|
|
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
import re
|
| 3 |
import time
|
| 4 |
+
from typing import Any, Iterable, List, Optional, Dict, Tuple
|
| 5 |
|
| 6 |
import sqlglot
|
| 7 |
from sqlglot import expressions as exp
|
|
|
|
| 9 |
from nl2sql.types import StageResult, StageTrace
|
| 10 |
from nl2sql.metrics import (
|
| 11 |
verifier_checks_total,
|
|
|
|
| 12 |
verifier_failures_total,
|
| 13 |
)
|
| 14 |
|
| 15 |
|
| 16 |
def _ms(t0: float) -> int:
|
| 17 |
+
"""Return elapsed milliseconds since t0, as int."""
|
| 18 |
return int((time.perf_counter() - t0) * 1000)
|
| 19 |
|
| 20 |
|
| 21 |
+
# ---------------- Small Levenshtein distance for schema matching ----------------
|
| 22 |
+
def _lev(a: str, b: str) -> int:
|
| 23 |
+
n = len(b)
|
| 24 |
+
|
| 25 |
+
dp = list(range(n + 1))
|
| 26 |
+
for i, ca in enumerate(a, 1):
|
| 27 |
+
prev, dp[0] = dp[0], i
|
| 28 |
+
for j, cb in enumerate(b, 1):
|
| 29 |
+
cur = min(
|
| 30 |
+
dp[j] + 1, # delete
|
| 31 |
+
dp[j - 1] + 1, # insert
|
| 32 |
+
prev + (0 if ca == cb else 1), # replace
|
| 33 |
+
)
|
| 34 |
+
prev, dp[j] = dp[j], cur
|
| 35 |
+
return dp[n]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _closest(name: str, candidates: List[str]) -> Tuple[str, int]:
|
| 39 |
+
"""Find the closest match (by edit distance) for a given name."""
|
| 40 |
+
best, dist = name, 10**9
|
| 41 |
+
for c in candidates:
|
| 42 |
+
d = _lev(name.lower(), c.lower())
|
| 43 |
+
if d < dist:
|
| 44 |
+
best, dist = c, d
|
| 45 |
+
return best, dist
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _maybe_singular(plural: str, tables: List[str]) -> Optional[str]:
|
| 49 |
+
"""Simple singularization heuristic: 'singers' -> 'singer'."""
|
| 50 |
+
if plural.endswith("s"):
|
| 51 |
+
cand = plural[:-1]
|
| 52 |
+
if cand in tables:
|
| 53 |
+
return cand
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ---------------- Verifier with schema-aware repair ----------------
|
| 58 |
class Verifier:
|
| 59 |
name = "verifier"
|
| 60 |
|
| 61 |
+
# Aggregate call detector used by both AST and regex fallbacks
|
| 62 |
_AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
|
| 63 |
|
| 64 |
+
# Fast token sanity: require SELECT and FROM to exist in the cleaned SQL
|
| 65 |
+
_REQ_SELECT = re.compile(r"\bselect\b", re.IGNORECASE)
|
| 66 |
+
_REQ_FROM = re.compile(r"\bfrom\b", re.IGNORECASE)
|
| 67 |
+
|
| 68 |
+
# ---------- AST helpers ----------
|
| 69 |
def _walk(self, node: exp.Expression) -> Iterable[exp.Expression]:
|
| 70 |
+
"""Depth-first traversal of a SQLGlot AST."""
|
| 71 |
stack = [node]
|
| 72 |
while stack:
|
| 73 |
cur = stack.pop()
|
|
|
|
| 83 |
stack.append(it)
|
| 84 |
|
| 85 |
def _first_select(self, tree: exp.Expression) -> Optional[exp.Select]:
|
| 86 |
+
"""Return the first SELECT node from the AST (if any)."""
|
| 87 |
for n in self._walk(tree):
|
| 88 |
if isinstance(n, exp.Select):
|
| 89 |
return n
|
|
|
|
| 91 |
|
| 92 |
def _has_group_by(self, tree: exp.Expression) -> bool:
|
| 93 |
sel = self._first_select(tree)
|
| 94 |
+
return bool(getattr(sel, "group", None)) if sel else False
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
def _is_distinct_projection(self, tree: exp.Expression) -> bool:
|
| 97 |
sel = self._first_select(tree)
|
| 98 |
if not sel:
|
| 99 |
return False
|
|
|
|
| 100 |
if getattr(sel, "distinct", None):
|
| 101 |
return True
|
| 102 |
return any(isinstance(n, exp.Distinct) for n in self._walk(sel))
|
| 103 |
|
| 104 |
def _has_windowed_aggregate(self, tree: exp.Expression) -> bool:
|
|
|
|
| 105 |
return any(isinstance(n, exp.Window) for n in self._walk(tree))
|
| 106 |
|
| 107 |
def _expr_contains_agg(self, node: exp.Expression) -> bool:
|
| 108 |
+
"""Return True if an expression contains an aggregate function."""
|
| 109 |
+
agg_names = {"count", "sum", "avg", "min", "max"}
|
| 110 |
agg_type_names = (
|
| 111 |
"Count",
|
| 112 |
"Sum",
|
|
|
|
| 117 |
"ArrayAgg",
|
| 118 |
"StringAgg",
|
| 119 |
)
|
| 120 |
+
agg_types = tuple(
|
| 121 |
+
t
|
| 122 |
+
for t in (getattr(exp, n, None) for n in agg_type_names)
|
| 123 |
+
if isinstance(t, type)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# AST type-based check (preferred)
|
| 127 |
+
if agg_types and any(isinstance(n, agg_types) for n in self._walk(node)):
|
|
|
|
| 128 |
return True
|
| 129 |
|
| 130 |
+
# Fallback: function-like name check
|
| 131 |
Anonymous = getattr(exp, "Anonymous", None)
|
| 132 |
func_like = (exp.Func,) + ((Anonymous,) if isinstance(Anonymous, type) else ())
|
|
|
|
| 133 |
|
| 134 |
+
def _fname(n: exp.Expression) -> str:
|
| 135 |
+
nm = getattr(n, "name", None)
|
| 136 |
+
if isinstance(nm, str) and nm:
|
| 137 |
+
return nm.lower()
|
| 138 |
this = getattr(n, "this", None)
|
| 139 |
if isinstance(this, str):
|
| 140 |
return this.lower()
|
|
|
|
| 144 |
return (str(this) or "").lower()
|
| 145 |
|
| 146 |
for n in self._walk(node):
|
| 147 |
+
if isinstance(n, func_like) and _fname(n) in agg_names:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
return True
|
| 149 |
return False
|
| 150 |
|
|
|
|
| 151 |
def _clean_sql_for_fn_scan(self, sql: str) -> str:
|
| 152 |
+
"""Normalize SQL before scanning for function names or keywords."""
|
| 153 |
s = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) # block comments
|
| 154 |
s = re.sub(r"--.*?$", " ", s, flags=re.MULTILINE) # line comments
|
| 155 |
s = re.sub(
|
| 156 |
r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`)", " ", s
|
| 157 |
+
) # quoted strings
|
| 158 |
s = re.sub(r"\s+", " ", s).strip()
|
| 159 |
return s
|
| 160 |
|
| 161 |
+
# ---------------- Schema-Guard Repair ----------------
|
| 162 |
+
def _schema_dict(self, adapter: Any) -> Optional[Dict[str, List[str]]]:
|
| 163 |
+
"""Fetch schema dict {table: [columns]} from adapter if available."""
|
| 164 |
+
if not adapter:
|
| 165 |
+
return None
|
| 166 |
+
get = getattr(adapter, "schema_dict", None)
|
| 167 |
+
if callable(get):
|
| 168 |
+
try:
|
| 169 |
+
d = get()
|
| 170 |
+
if isinstance(d, dict):
|
| 171 |
+
return {str(k): list(v) for k, v in d.items()}
|
| 172 |
+
except Exception:
|
| 173 |
+
return None
|
| 174 |
return None
|
| 175 |
|
| 176 |
+
def _repair_with_schema(
|
| 177 |
+
self, sql: str, schema: Dict[str, List[str]]
|
| 178 |
+
) -> Tuple[str, bool, List[str]]:
|
| 179 |
+
"""Try to fix table/column names using schema similarity (singularize + closest edit-distance <= 2)."""
|
| 180 |
+
notes: List[str] = []
|
| 181 |
+
try:
|
| 182 |
+
ast = sqlglot.parse_one(sql)
|
| 183 |
+
except Exception as e:
|
| 184 |
+
return sql, False, [f"parse_error:{e!s}"]
|
| 185 |
+
|
| 186 |
+
tables = list(schema.keys())
|
| 187 |
+
changed = False
|
| 188 |
+
|
| 189 |
+
# Fix table names
|
| 190 |
+
def _fix_table(node: exp.Expression) -> exp.Expression:
|
| 191 |
+
nonlocal changed
|
| 192 |
+
if isinstance(node, exp.Table):
|
| 193 |
+
orig = node.name
|
| 194 |
+
if orig in schema:
|
| 195 |
+
return node
|
| 196 |
+
s1 = _maybe_singular(orig, tables)
|
| 197 |
+
if s1:
|
| 198 |
+
changed = True
|
| 199 |
+
return exp.Table(this=sqlglot.to_identifier(s1))
|
| 200 |
+
best, dist = _closest(orig, tables)
|
| 201 |
+
if dist <= 2:
|
| 202 |
+
changed = True
|
| 203 |
+
return exp.Table(this=sqlglot.to_identifier(best))
|
| 204 |
+
return node
|
| 205 |
+
|
| 206 |
+
ast = ast.transform(_fix_table)
|
| 207 |
+
|
| 208 |
+
# Fix column names
|
| 209 |
+
def _fix_col(node: exp.Expression) -> exp.Expression:
|
| 210 |
+
nonlocal changed
|
| 211 |
+
if isinstance(node, exp.Column):
|
| 212 |
+
name = node.name
|
| 213 |
+
if not name:
|
| 214 |
+
return node
|
| 215 |
+
tbl = node.table
|
| 216 |
+
if tbl and tbl in schema:
|
| 217 |
+
candidates = schema[tbl]
|
| 218 |
+
else:
|
| 219 |
+
candidates = [c for cols in schema.values() for c in cols]
|
| 220 |
+
if name in candidates:
|
| 221 |
+
return node
|
| 222 |
+
best, dist = _closest(name, candidates) if candidates else (name, 99)
|
| 223 |
+
if dist <= 2:
|
| 224 |
+
changed = True
|
| 225 |
+
node.set("this", sqlglot.to_identifier(best))
|
| 226 |
+
return node
|
| 227 |
+
|
| 228 |
+
ast = ast.transform(_fix_col)
|
| 229 |
+
|
| 230 |
+
if not changed:
|
| 231 |
+
return sql, True, notes
|
| 232 |
|
| 233 |
+
try:
|
| 234 |
+
repaired = ast.sql(dialect="sqlite")
|
| 235 |
+
except Exception as e:
|
| 236 |
+
return sql, False, notes + [f"rebuild_error:{e!s}"]
|
| 237 |
+
|
| 238 |
+
notes.append("schema_guard_repair")
|
| 239 |
+
return repaired, True, notes
|
| 240 |
+
|
| 241 |
+
# ---------------- Main verifier logic ----------------
|
| 242 |
+
def verify(
|
| 243 |
+
self, sql: str, *, exec_result: Any = None, adapter: Any = None
|
| 244 |
+
) -> StageResult:
|
| 245 |
+
"""
|
| 246 |
+
Verify syntax, basic semantics, and optionally schema correctness and preview-execution.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
StageResult with:
|
| 250 |
+
- ok: boolean
|
| 251 |
+
- data: may include {"verified": True, "sql": <repaired_sql>}
|
| 252 |
+
- trace: StageTrace(stage="verifier", duration_ms=...)
|
| 253 |
+
"""
|
| 254 |
t0 = time.perf_counter()
|
| 255 |
issues: List[str] = []
|
| 256 |
+
repaired_sql = None
|
| 257 |
|
| 258 |
+
# 0) Fast token sanity: must contain SELECT and FROM (handles typos like SELCT/FRM).
|
| 259 |
+
sql_scan = self._clean_sql_for_fn_scan(sql)
|
| 260 |
+
if not self._REQ_SELECT.search(sql_scan) or not self._REQ_FROM.search(sql_scan):
|
| 261 |
+
verifier_checks_total.labels(ok="false").inc()
|
| 262 |
+
verifier_failures_total.labels(reason="parse_error").inc()
|
| 263 |
+
return StageResult(
|
| 264 |
+
ok=False,
|
| 265 |
+
error=["parse_error"],
|
| 266 |
+
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 267 |
+
)
|
| 268 |
|
| 269 |
+
# 1) Syntax validation via sqlglot
|
| 270 |
+
try:
|
| 271 |
+
tree = sqlglot.parse_one(sql, read=None)
|
| 272 |
if tree is None:
|
| 273 |
return StageResult(
|
| 274 |
ok=False,
|
| 275 |
error=["parse_error"],
|
| 276 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 277 |
)
|
|
|
|
|
|
|
|
|
|
| 278 |
tree_type = type(tree).__name__
|
|
|
|
|
|
|
|
|
|
| 279 |
if tree_type in ("Command", "Unknown"):
|
| 280 |
verifier_checks_total.labels(ok="false").inc()
|
| 281 |
verifier_failures_total.labels(reason="parse_error").inc()
|
|
|
|
| 284 |
error=["parse_error"],
|
| 285 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 286 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
except Exception:
|
| 288 |
verifier_checks_total.labels(ok="false").inc()
|
| 289 |
verifier_failures_total.labels(reason="parse_error").inc()
|
|
|
|
| 293 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 294 |
)
|
| 295 |
|
| 296 |
+
# 2) Semantic rule: avoid aggregate + non-aggregate mix without GROUP BY (unless DISTINCT/window)
|
| 297 |
try:
|
| 298 |
sel = self._first_select(tree)
|
| 299 |
if sel:
|
| 300 |
has_group = self._has_group_by(tree)
|
| 301 |
has_window = self._has_windowed_aggregate(tree)
|
| 302 |
is_distinct = self._is_distinct_projection(tree)
|
|
|
|
| 303 |
select_items = list(getattr(sel, "expressions", []) or [])
|
| 304 |
any_agg = any(self._expr_contains_agg(it) for it in select_items)
|
| 305 |
+
any_nonagg_col = any(
|
| 306 |
+
(
|
| 307 |
+
any(isinstance(n, exp.Column) for n in self._walk(it))
|
| 308 |
+
and not self._expr_contains_agg(it)
|
| 309 |
+
)
|
| 310 |
+
for it in select_items
|
| 311 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
if (
|
| 313 |
any_agg
|
| 314 |
and any_nonagg_col
|
|
|
|
| 317 |
verifier_failures_total.labels(reason="semantic_error").inc()
|
| 318 |
issues.append("aggregation_without_group_by")
|
| 319 |
except Exception as e:
|
|
|
|
| 320 |
verifier_failures_total.labels(reason="semantic_error").inc()
|
| 321 |
issues.append(f"semantic_check_error:{e!s}")
|
| 322 |
+
# 2b) Regex fallback for aggregate + non-aggregate without GROUP BY.
|
| 323 |
+
# Skip if DISTINCT or any WINDOW (OVER ...) is present in the SELECT list.
|
| 324 |
+
try:
|
| 325 |
+
low = sql_scan.lower()
|
| 326 |
+
if "group by" not in low and "distinct" not in low:
|
| 327 |
+
m = re.search(
|
| 328 |
+
r"select\s+(?P<sel>.+?)\s+from\b",
|
| 329 |
+
sql_scan,
|
| 330 |
+
flags=re.IGNORECASE | re.DOTALL,
|
|
|
|
| 331 |
)
|
| 332 |
+
if m:
|
| 333 |
+
sel_clause = m.group("sel")
|
| 334 |
+
# If window functions are present, allow (COUNT(*) OVER (...), etc.)
|
| 335 |
+
if re.search(r"\bover\b", sel_clause, flags=re.IGNORECASE):
|
| 336 |
+
pass # windowed aggregates are acceptable without GROUP BY
|
| 337 |
+
else:
|
| 338 |
+
has_agg = bool(self._AGG_CALL_RE.search(sel_clause))
|
| 339 |
+
# Heuristic: presence of a comma OR a bare identifier besides pure aggregate-only select
|
| 340 |
+
has_bare_col = "," in sel_clause or (
|
| 341 |
+
bool(re.search(r"\b[a-zA-Z_][\w.]*\b", sel_clause))
|
| 342 |
+
and not re.fullmatch(
|
| 343 |
+
r"\s*(count|sum|avg|min|max)\s*\([^)]*\)\s*",
|
| 344 |
+
sel_clause,
|
| 345 |
+
flags=re.IGNORECASE,
|
| 346 |
+
)
|
| 347 |
+
)
|
| 348 |
+
if (
|
| 349 |
+
has_agg
|
| 350 |
+
and has_bare_col
|
| 351 |
+
and "aggregation_without_group_by" not in issues
|
| 352 |
+
):
|
| 353 |
verifier_failures_total.labels(
|
| 354 |
+
reason="semantic_error"
|
| 355 |
).inc()
|
| 356 |
issues.append("aggregation_without_group_by")
|
| 357 |
+
except Exception:
|
| 358 |
+
# Non-fatal; AST path already attempted.
|
| 359 |
+
pass
|
| 360 |
+
|
| 361 |
+
# 3) Schema-based auto-repair (optional)
|
| 362 |
+
schema = self._schema_dict(adapter)
|
| 363 |
+
if schema:
|
| 364 |
+
fixed, ok_fix, notes = self._repair_with_schema(sql, schema)
|
| 365 |
+
if ok_fix is True and fixed != sql:
|
| 366 |
+
repaired_sql = fixed
|
| 367 |
+
if notes:
|
| 368 |
+
issues.extend(
|
| 369 |
+
[f"note:{n}" for n in notes if not n.startswith("parse_error")]
|
| 370 |
+
)
|
| 371 |
|
| 372 |
+
# 4) Preview execution check:
|
| 373 |
+
# - If exec_result is provided, use it directly
|
| 374 |
+
# - Otherwise, if adapter has execute_preview, run it
|
| 375 |
try:
|
| 376 |
+
if exec_result is not None:
|
| 377 |
+
er = exec_result
|
| 378 |
+
elif adapter is not None and hasattr(adapter, "execute_preview"):
|
| 379 |
+
er = adapter.execute_preview(repaired_sql or sql)
|
| 380 |
+
else:
|
| 381 |
+
er = {"ok": True}
|
| 382 |
+
|
| 383 |
+
ok_val = (
|
| 384 |
+
isinstance(er, dict) and isinstance(er.get("ok"), bool) and er["ok"]
|
| 385 |
+
)
|
| 386 |
+
if not ok_val:
|
| 387 |
+
msg = None
|
| 388 |
+
if isinstance(er, dict):
|
| 389 |
+
for k in ("error", "message", "detail"):
|
| 390 |
+
if k in er and er[k]:
|
| 391 |
+
msg = str(er[k])
|
| 392 |
+
break
|
| 393 |
verifier_failures_total.labels(reason="preview_exec_error").inc()
|
| 394 |
+
issues.append(f"exec_error:{msg or 'preview_failed'}")
|
| 395 |
except Exception as e:
|
| 396 |
verifier_failures_total.labels(reason="preview_exec_error").inc()
|
| 397 |
issues.append(f"exec_exception:{e!s}")
|
| 398 |
|
| 399 |
+
# 5) Final result and trace
|
| 400 |
+
is_ok: bool = (not issues) or all(i.startswith("note:") for i in issues)
|
| 401 |
+
ok_label: str = "true" if is_ok else "false"
|
| 402 |
+
verifier_checks_total.labels(ok=ok_label).inc()
|
| 403 |
+
|
| 404 |
+
if is_ok:
|
| 405 |
+
data: Dict[str, Any] = {"verified": True}
|
| 406 |
+
if repaired_sql:
|
| 407 |
+
data["sql"] = repaired_sql
|
| 408 |
+
return StageResult(
|
| 409 |
+
ok=True,
|
| 410 |
+
data=data,
|
| 411 |
+
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 412 |
+
)
|
| 413 |
+
else:
|
| 414 |
return StageResult(
|
| 415 |
ok=False,
|
| 416 |
+
error=[i for i in issues if not i.startswith("note:")],
|
| 417 |
trace=StageTrace(
|
| 418 |
stage=self.name, duration_ms=_ms(t0), notes={"issues": issues}
|
| 419 |
),
|
| 420 |
)
|
| 421 |
|
| 422 |
+
# Public alias for backward compatibility
|
| 423 |
+
def run(
|
| 424 |
+
self, *, sql: str, exec_result: Any = None, adapter: Any = None
|
| 425 |
+
) -> StageResult:
|
| 426 |
+
"""Back-compat wrapper around verify()."""
|
| 427 |
+
return self.verify(sql, exec_result=exec_result, adapter=adapter)
|
|
|
|
|
|
|
|
|
|
|
|