Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
e3e0ac5
1
Parent(s):
3716701
refactor(core): trace schema upgrade, verifier/executor sync, benchmark plot polish
Browse files- adapters/llm/base.py +1 -1
- adapters/llm/openai_provider.py +48 -1
- benchmarks/evaluate_spider_pro.py +39 -0
- benchmarks/plot_results.py +31 -1
- benchmarks/results_pro/20251109-105728/eval.jsonl +5 -0
- benchmarks/results_pro/20251109-105728/latency_histogram.png +0 -0
- benchmarks/results_pro/20251109-105728/latency_per_stage.png +0 -0
- benchmarks/results_pro/20251109-105728/metrics_overview.png +0 -0
- benchmarks/results_pro/20251109-105728/results.csv +6 -0
- benchmarks/results_pro/20251109-105728/summary.json +21 -0
- benchmarks/results_pro/20251109-110149/eval.jsonl +5 -0
- benchmarks/results_pro/20251109-110149/latency_histogram.png +0 -0
- benchmarks/results_pro/20251109-110149/latency_per_stage.png +0 -0
- benchmarks/results_pro/20251109-110149/metrics_overview.png +0 -0
- benchmarks/results_pro/20251109-110149/results.csv +6 -0
- benchmarks/results_pro/20251109-110149/summary.json +21 -0
- nl2sql/executor.py +10 -2
- nl2sql/pipeline.py +299 -306
- nl2sql/planner.py +18 -17
- nl2sql/verifier.py +107 -391
adapters/llm/base.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Tuple, Dict, Any, Protocol
|
|
| 3 |
|
| 4 |
|
| 5 |
class LLMProvider(Protocol):
|
| 6 |
-
|
| 7 |
|
| 8 |
def plan(
|
| 9 |
self, *, user_query: str, schema_preview: str
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class LLMProvider(Protocol):
|
| 6 |
+
PROVIDER_ID: str
|
| 7 |
|
| 8 |
def plan(
|
| 9 |
self, *, user_query: str, schema_preview: str
|
adapters/llm/openai_provider.py
CHANGED
|
@@ -37,7 +37,11 @@ def _resolve_api_config() -> tuple[str, str, str]:
|
|
| 37 |
class OpenAIProvider(LLMProvider):
|
| 38 |
"""OpenAI LLM provider implementation."""
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def __init__(self) -> None:
|
| 43 |
"""Initialize OpenAI client with config from environment."""
|
|
@@ -46,6 +50,8 @@ class OpenAIProvider(LLMProvider):
|
|
| 46 |
os.environ["OPENAI_BASE_URL"] = base_url
|
| 47 |
self.client = OpenAI()
|
| 48 |
self.model = model
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def plan(
|
| 51 |
self, *, user_query: str, schema_preview: str
|
|
@@ -94,8 +100,20 @@ Create a step-by-step plan to answer this question with SQL."""
|
|
| 94 |
prompt_tokens = usage.prompt_tokens
|
| 95 |
completion_tokens = usage.completion_tokens
|
| 96 |
cost = self._estimate_cost(usage)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
return (msg, prompt_tokens, completion_tokens, cost)
|
| 98 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
return (msg, 0, 0, 0.0)
|
| 100 |
|
| 101 |
def generate_sql(
|
|
@@ -197,12 +215,27 @@ Now generate the SQL for the given question:"""
|
|
| 197 |
if not sql:
|
| 198 |
raise ValueError("LLM returned empty 'sql'")
|
| 199 |
|
|
|
|
| 200 |
if usage:
|
| 201 |
prompt_tokens = usage.prompt_tokens
|
| 202 |
completion_tokens = usage.completion_tokens
|
| 203 |
cost = self._estimate_cost(usage)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
return (sql, rationale, prompt_tokens, completion_tokens, cost)
|
| 205 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
return (sql, rationale, 0, 0, 0.0)
|
| 207 |
|
| 208 |
def _simplify_sql(self, sql: str) -> str:
|
|
@@ -307,8 +340,22 @@ Return the corrected SQL (keep it simple):"""
|
|
| 307 |
prompt_tokens = usage.prompt_tokens
|
| 308 |
completion_tokens = usage.completion_tokens
|
| 309 |
cost = self._estimate_cost(usage)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
return (fixed_sql, prompt_tokens, completion_tokens, cost)
|
| 311 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
return (fixed_sql, 0, 0, 0.0)
|
| 313 |
|
| 314 |
def _estimate_cost(self, usage: Any) -> float:
|
|
|
|
| 37 |
class OpenAIProvider(LLMProvider):
|
| 38 |
"""OpenAI LLM provider implementation."""
|
| 39 |
|
| 40 |
+
PROVIDER_ID = "openai"
|
| 41 |
+
|
| 42 |
+
def get_last_usage(self) -> dict[str, Any]:
|
| 43 |
+
"""Return metadata of the last LLM call (tokens, cost, sql_length, kind)."""
|
| 44 |
+
return dict(self._last_usage)
|
| 45 |
|
| 46 |
def __init__(self) -> None:
|
| 47 |
"""Initialize OpenAI client with config from environment."""
|
|
|
|
| 50 |
os.environ["OPENAI_BASE_URL"] = base_url
|
| 51 |
self.client = OpenAI()
|
| 52 |
self.model = model
|
| 53 |
+
# last call usage/metadata for tracing
|
| 54 |
+
self._last_usage: dict[str, Any] = {}
|
| 55 |
|
| 56 |
def plan(
|
| 57 |
self, *, user_query: str, schema_preview: str
|
|
|
|
| 100 |
prompt_tokens = usage.prompt_tokens
|
| 101 |
completion_tokens = usage.completion_tokens
|
| 102 |
cost = self._estimate_cost(usage)
|
| 103 |
+
self._last_usage = {
|
| 104 |
+
"kind": "plan",
|
| 105 |
+
"prompt_tokens": prompt_tokens,
|
| 106 |
+
"completion_tokens": completion_tokens,
|
| 107 |
+
"cost_usd": cost,
|
| 108 |
+
}
|
| 109 |
return (msg, prompt_tokens, completion_tokens, cost)
|
| 110 |
else:
|
| 111 |
+
self._last_usage = {
|
| 112 |
+
"kind": "plan",
|
| 113 |
+
"prompt_tokens": 0,
|
| 114 |
+
"completion_tokens": 0,
|
| 115 |
+
"cost_usd": 0.0,
|
| 116 |
+
}
|
| 117 |
return (msg, 0, 0, 0.0)
|
| 118 |
|
| 119 |
def generate_sql(
|
|
|
|
| 215 |
if not sql:
|
| 216 |
raise ValueError("LLM returned empty 'sql'")
|
| 217 |
|
| 218 |
+
sql_length = len(sql)
|
| 219 |
if usage:
|
| 220 |
prompt_tokens = usage.prompt_tokens
|
| 221 |
completion_tokens = usage.completion_tokens
|
| 222 |
cost = self._estimate_cost(usage)
|
| 223 |
+
self._last_usage = {
|
| 224 |
+
"kind": "generate",
|
| 225 |
+
"prompt_tokens": prompt_tokens,
|
| 226 |
+
"completion_tokens": completion_tokens,
|
| 227 |
+
"cost_usd": cost,
|
| 228 |
+
"sql_length": sql_length,
|
| 229 |
+
}
|
| 230 |
return (sql, rationale, prompt_tokens, completion_tokens, cost)
|
| 231 |
else:
|
| 232 |
+
self._last_usage = {
|
| 233 |
+
"kind": "generate",
|
| 234 |
+
"prompt_tokens": 0,
|
| 235 |
+
"completion_tokens": 0,
|
| 236 |
+
"cost_usd": 0.0,
|
| 237 |
+
"sql_length": sql_length,
|
| 238 |
+
}
|
| 239 |
return (sql, rationale, 0, 0, 0.0)
|
| 240 |
|
| 241 |
def _simplify_sql(self, sql: str) -> str:
|
|
|
|
| 340 |
prompt_tokens = usage.prompt_tokens
|
| 341 |
completion_tokens = usage.completion_tokens
|
| 342 |
cost = self._estimate_cost(usage)
|
| 343 |
+
self._last_usage = {
|
| 344 |
+
"kind": "repair",
|
| 345 |
+
"prompt_tokens": prompt_tokens,
|
| 346 |
+
"completion_tokens": completion_tokens,
|
| 347 |
+
"cost_usd": cost,
|
| 348 |
+
"sql_length": len(fixed_sql),
|
| 349 |
+
}
|
| 350 |
return (fixed_sql, prompt_tokens, completion_tokens, cost)
|
| 351 |
else:
|
| 352 |
+
self._last_usage = {
|
| 353 |
+
"kind": "repair",
|
| 354 |
+
"prompt_tokens": 0,
|
| 355 |
+
"completion_tokens": 0,
|
| 356 |
+
"cost_usd": 0.0,
|
| 357 |
+
"sql_length": len(fixed_sql),
|
| 358 |
+
}
|
| 359 |
return (fixed_sql, 0, 0, 0.0)
|
| 360 |
|
| 361 |
def _estimate_cost(self, usage: Any) -> float:
|
benchmarks/evaluate_spider_pro.py
CHANGED
|
@@ -206,6 +206,45 @@ def evaluate_sql(pred: str, gold: str, db: Path) -> Dict[str, float]:
|
|
| 206 |
return {"em": em, "sm": sm, "exec": exec_acc}
|
| 207 |
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# ---------------------- Dataclass + runner ------------------
|
| 210 |
|
| 211 |
|
|
|
|
| 206 |
return {"em": em, "sm": sm, "exec": exec_acc}
|
| 207 |
|
| 208 |
|
| 209 |
+
# ---------------------- Trace flatten helpers -------------------
|
| 210 |
+
def _flatten_trace_entry(d: Dict[str, Any]) -> Dict[str, Any]:
|
| 211 |
+
out = dict(d or {})
|
| 212 |
+
notes = out.pop("notes", {}) or {}
|
| 213 |
+
# promote selected keys to top-level for easier analysis
|
| 214 |
+
for k in (
|
| 215 |
+
"tokens_in",
|
| 216 |
+
"tokens_out",
|
| 217 |
+
"cost_usd",
|
| 218 |
+
"sql_length",
|
| 219 |
+
"row_count",
|
| 220 |
+
"verified",
|
| 221 |
+
"error_type",
|
| 222 |
+
"repair_attempts",
|
| 223 |
+
"skipped",
|
| 224 |
+
"col_count",
|
| 225 |
+
):
|
| 226 |
+
if k in notes:
|
| 227 |
+
out[k] = notes[k]
|
| 228 |
+
if notes:
|
| 229 |
+
out["notes"] = notes
|
| 230 |
+
return out
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _per_stage_ms(trace_list: List[Dict[str, Any]]) -> Dict[str, float]:
|
| 234 |
+
acc = {s: 0.0 for s in STAGES}
|
| 235 |
+
cnt = {s: 0 for s in STAGES}
|
| 236 |
+
for t in trace_list:
|
| 237 |
+
s = t.get("stage")
|
| 238 |
+
if s in acc:
|
| 239 |
+
ms = t.get("duration_ms", t.get("ms", 0.0))
|
| 240 |
+
try:
|
| 241 |
+
acc[s] += float(ms)
|
| 242 |
+
cnt[s] += 1
|
| 243 |
+
except Exception:
|
| 244 |
+
pass
|
| 245 |
+
return {s: round(acc[s] / cnt[s], 2) if cnt[s] else 0.0 for s in STAGES}
|
| 246 |
+
|
| 247 |
+
|
| 248 |
# ---------------------- Dataclass + runner ------------------
|
| 249 |
|
| 250 |
|
benchmarks/plot_results.py
CHANGED
|
@@ -124,6 +124,35 @@ def plot_latency_per_stage(run: Path, summary: dict, rows: list[dict]) -> None:
|
|
| 124 |
plt.close()
|
| 125 |
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
def main() -> None:
|
| 128 |
run = _latest_run_dir()
|
| 129 |
print(f"📂 Using latest run: {run.name}")
|
|
@@ -132,8 +161,9 @@ def main() -> None:
|
|
| 132 |
plot_metrics_overview(run, summary)
|
| 133 |
plot_latency_hist(run, rows)
|
| 134 |
plot_latency_per_stage(run, summary, rows)
|
|
|
|
| 135 |
print(
|
| 136 |
-
"✅ Saved: metrics_overview.png, latency_histogram.png, latency_per_stage.png"
|
| 137 |
)
|
| 138 |
|
| 139 |
|
|
|
|
| 124 |
plt.close()
|
| 125 |
|
| 126 |
|
| 127 |
+
def plot_errors_overview(run: Path) -> None:
|
| 128 |
+
p = run / "trace.jsonl"
|
| 129 |
+
if not p.exists():
|
| 130 |
+
return
|
| 131 |
+
from collections import Counter
|
| 132 |
+
|
| 133 |
+
counts: Counter[str] = Counter()
|
| 134 |
+
with p.open("r", encoding="utf-8") as f:
|
| 135 |
+
for line in f:
|
| 136 |
+
try:
|
| 137 |
+
obj = json.loads(line)
|
| 138 |
+
except Exception:
|
| 139 |
+
continue
|
| 140 |
+
for t in obj.get("trace", []):
|
| 141 |
+
et = t.get("error_type")
|
| 142 |
+
if et:
|
| 143 |
+
counts[et] += 1
|
| 144 |
+
if not counts:
|
| 145 |
+
return
|
| 146 |
+
labels, values = zip(*sorted(counts.items(), key=lambda x: x[1], reverse=True))
|
| 147 |
+
plt.figure(figsize=(9, 4))
|
| 148 |
+
plt.bar(labels, values)
|
| 149 |
+
plt.title("Errors by Type")
|
| 150 |
+
plt.ylabel("Count")
|
| 151 |
+
plt.tight_layout()
|
| 152 |
+
plt.savefig(run / "errors_overview.png")
|
| 153 |
+
plt.close()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
def main() -> None:
|
| 157 |
run = _latest_run_dir()
|
| 158 |
print(f"📂 Using latest run: {run.name}")
|
|
|
|
| 161 |
plot_metrics_overview(run, summary)
|
| 162 |
plot_latency_hist(run, rows)
|
| 163 |
plot_latency_per_stage(run, summary, rows)
|
| 164 |
+
plot_errors_overview(run)
|
| 165 |
print(
|
| 166 |
+
"✅ Saved: metrics_overview.png, latency_histogram.png, latency_per_stage.png, errors_overview.png"
|
| 167 |
)
|
| 168 |
|
| 169 |
|
benchmarks/results_pro/20251109-105728/eval.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 11836, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 6838, "summary": "ok", "notes": {"len_plan": 1460}, "token_in": 265, "token_out": 356, "cost_usd": 0.00025334999999999995, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3409, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 838, "token_out": 19, "cost_usd": 0.0001371, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 832, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": null, "row_count": null, 
"verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 744, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, 
"repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
| 2 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 10414, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 5346, "summary": "ok", "notes": {"len_plan": 1385}, "token_in": 266, "token_out": 334, "cost_usd": 0.00024029999999999999, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3352, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 817, "token_out": 19, "cost_usd": 0.00013394999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 4, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 871, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": 
null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 831, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, 
"error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
| 3 |
+
{"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}, "skipped": false}]}
|
| 4 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 13807, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 8248, "summary": "ok", "notes": {"len_plan": 1415}, "token_in": 276, "token_out": 335, "cost_usd": 0.0002424, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3686, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 828, "token_out": 37, "cost_usd": 0.00014639999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 55}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 960, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 
64, "attempt": 1}, "token_in": 320, "token_out": 21, "cost_usd": 6.0599999999999996e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 901, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64, "attempt": 2}, "token_in": 323, "token_out": 21, "cost_usd": 6.104999999999999e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, 
"token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
| 5 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 13396, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7141, "summary": "ok", "notes": {"len_plan": 1569}, "token_in": 274, "token_out": 404, "cost_usd": 0.0002835, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 4139, "summary": "ok", "notes": {"rationale_len": 87}, "token_in": 895, "token_out": 46, "cost_usd": 0.00016184999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 937, "summary": 
"ok", "notes": {"old_sql_len": 72, "new_sql_len": 80, "attempt": 1}, "token_in": 328, "token_out": 24, "cost_usd": 6.36e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 80}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1160, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72, "attempt": 2}, "token_in": 332, "token_out": 21, "cost_usd": 6.24e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", 
"notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
benchmarks/results_pro/20251109-105728/latency_histogram.png
ADDED
|
benchmarks/results_pro/20251109-105728/latency_per_stage.png
ADDED
|
benchmarks/results_pro/20251109-105728/metrics_overview.png
ADDED
|
benchmarks/results_pro/20251109-105728/results.csv
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
db_id,query,ok,em,sm,exec_acc,latency_ms
|
| 2 |
+
concert_singer,"How many singers do we have?",✅,1.0,1.0,1.0,11836
|
| 3 |
+
concert_singer,"What is the total number of singers?",✅,1.0,1.0,1.0,10414
|
| 4 |
+
concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",✅,0.0,0.0,0.0,0
|
| 5 |
+
concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",✅,0.0,1.0,1.0,13807
|
| 6 |
+
concert_singer,"What is the average, minimum, and maximum age of all singers from France?",✅,0.0,1.0,1.0,13396
|
benchmarks/results_pro/20251109-105728/summary.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": "2025-11-09T10:58:17",
|
| 3 |
+
"split": "dev",
|
| 4 |
+
"config": "configs/sqlite_pipeline.yaml",
|
| 5 |
+
"total": 5,
|
| 6 |
+
"success": 5,
|
| 7 |
+
"success_rate": 1.0,
|
| 8 |
+
"avg_latency_ms": 9890.6,
|
| 9 |
+
"p50_latency_ms": 11836.0,
|
| 10 |
+
"p95_latency_ms": 13724.8,
|
| 11 |
+
"EM": 0.4,
|
| 12 |
+
"SM": 0.8,
|
| 13 |
+
"ExecAcc": 0.8,
|
| 14 |
+
"detector_avg_ms": 0.0,
|
| 15 |
+
"planner_avg_ms": 6893.25,
|
| 16 |
+
"generator_avg_ms": 3646.5,
|
| 17 |
+
"safety_avg_ms": 1.67,
|
| 18 |
+
"executor_avg_ms": 1.33,
|
| 19 |
+
"verifier_avg_ms": 0.42,
|
| 20 |
+
"repair_avg_ms": 904.5
|
| 21 |
+
}
|
benchmarks/results_pro/20251109-110149/eval.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 12419, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7389, "summary": "ok", "notes": {"len_plan": 1297}, "token_in": 265, "token_out": 305, "cost_usd": 0.00022274999999999997, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3333, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 787, "token_out": 19, "cost_usd": 0.00012945, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 812, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": null, "row_count": null, 
"verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 867, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, 
"repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
| 2 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 13653, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 8492, "summary": "ok", "notes": {"len_plan": 1444}, "token_in": 266, "token_out": 343, "cost_usd": 0.0002457, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3127, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 826, "token_out": 19, "cost_usd": 0.00013529999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 914, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": null, 
"row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1108, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, 
"error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
| 3 |
+
{"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}, "skipped": false}]}
|
| 4 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 12306, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 6684, "summary": "ok", "notes": {"len_plan": 1253}, "token_in": 276, "token_out": 287, "cost_usd": 0.0002136, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3456, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 780, "token_out": 37, "cost_usd": 0.0001392, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 55}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 911, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64, 
"attempt": 1}, "token_in": 320, "token_out": 21, "cost_usd": 6.0599999999999996e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1239, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64, "attempt": 2}, "token_in": 323, "token_out": 21, "cost_usd": 6.104999999999999e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": 
null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
| 5 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 14824, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 9466, "summary": "ok", "notes": {"len_plan": 1418}, "token_in": 274, "token_out": 352, "cost_usd": 0.00025229999999999995, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 2949, "summary": "ok", "notes": {"rationale_len": 87}, "token_in": 843, "token_out": 46, "cost_usd": 0.00015404999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1139, 
"summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80, "attempt": 1}, "token_in": 328, "token_out": 24, "cost_usd": 6.36e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 80}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1250, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72, "attempt": 2}, "token_in": 332, "token_out": 21, "cost_usd": 6.24e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": 
"ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
|
benchmarks/results_pro/20251109-110149/latency_histogram.png
ADDED
|
benchmarks/results_pro/20251109-110149/latency_per_stage.png
ADDED
|
benchmarks/results_pro/20251109-110149/metrics_overview.png
ADDED
|
benchmarks/results_pro/20251109-110149/results.csv
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
db_id,query,ok,em,sm,exec_acc,latency_ms
|
| 2 |
+
concert_singer,"How many singers do we have?",✅,1.0,1.0,1.0,12419
|
| 3 |
+
concert_singer,"What is the total number of singers?",✅,1.0,1.0,1.0,13653
|
| 4 |
+
concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",✅,0.0,0.0,0.0,0
|
| 5 |
+
concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",✅,0.0,1.0,1.0,12306
|
| 6 |
+
concert_singer,"What is the average, minimum, and maximum age of all singers from France?",✅,0.0,1.0,1.0,14824
|
benchmarks/results_pro/20251109-110149/summary.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": "2025-11-09T11:02:43",
|
| 3 |
+
"split": "dev",
|
| 4 |
+
"config": "configs/sqlite_pipeline.yaml",
|
| 5 |
+
"total": 5,
|
| 6 |
+
"success": 5,
|
| 7 |
+
"success_rate": 1.0,
|
| 8 |
+
"avg_latency_ms": 10640.4,
|
| 9 |
+
"p50_latency_ms": 12419.0,
|
| 10 |
+
"p95_latency_ms": 14589.8,
|
| 11 |
+
"EM": 0.4,
|
| 12 |
+
"SM": 0.8,
|
| 13 |
+
"ExecAcc": 0.8,
|
| 14 |
+
"detector_avg_ms": 0.0,
|
| 15 |
+
"planner_avg_ms": 8007.75,
|
| 16 |
+
"generator_avg_ms": 3216.25,
|
| 17 |
+
"safety_avg_ms": 2.0,
|
| 18 |
+
"executor_avg_ms": 1.25,
|
| 19 |
+
"verifier_avg_ms": 0.58,
|
| 20 |
+
"repair_avg_ms": 1030.0
|
| 21 |
+
}
|
nl2sql/executor.py
CHANGED
|
@@ -16,7 +16,11 @@ class Executor:
|
|
| 16 |
trace = StageTrace(
|
| 17 |
stage=self.name,
|
| 18 |
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 19 |
-
notes={
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
)
|
| 21 |
return StageResult(
|
| 22 |
ok=True, data={"rows": rows, "columns": cols}, trace=trace
|
|
@@ -25,6 +29,10 @@ class Executor:
|
|
| 25 |
trace = StageTrace(
|
| 26 |
stage=self.name,
|
| 27 |
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 28 |
-
notes={
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
)
|
| 30 |
return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
|
|
|
|
| 16 |
trace = StageTrace(
|
| 17 |
stage=self.name,
|
| 18 |
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 19 |
+
notes={
|
| 20 |
+
"row_count": len(rows),
|
| 21 |
+
"col_count": len(cols),
|
| 22 |
+
"sql_length": len(sql or ""),
|
| 23 |
+
},
|
| 24 |
)
|
| 25 |
return StageResult(
|
| 26 |
ok=True, data={"rows": rows, "columns": cols}, trace=trace
|
|
|
|
| 29 |
trace = StageTrace(
|
| 30 |
stage=self.name,
|
| 31 |
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 32 |
+
notes={
|
| 33 |
+
"error": str(e),
|
| 34 |
+
"error_type": type(e).__name__,
|
| 35 |
+
"sql_length": len(sql or ""),
|
| 36 |
+
},
|
| 37 |
)
|
| 38 |
return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
|
nl2sql/pipeline.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
|
|
|
| 1 |
from __future__ import annotations
|
|
|
|
|
|
|
| 2 |
import traceback
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
-
from typing import Dict,
|
| 5 |
-
import
|
| 6 |
|
| 7 |
from nl2sql.types import StageResult
|
| 8 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
|
@@ -14,6 +18,7 @@ from nl2sql.verifier import Verifier
|
|
| 14 |
from nl2sql.repair import Repair
|
| 15 |
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
|
| 16 |
from nl2sql.metrics import stage_duration_ms, pipeline_runs_total
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
@dataclass(frozen=True)
|
|
@@ -32,7 +37,7 @@ class FinalResult:
|
|
| 32 |
class Pipeline:
|
| 33 |
"""
|
| 34 |
NL2SQL Copilot pipeline:
|
| 35 |
-
detector
|
| 36 |
"""
|
| 37 |
|
| 38 |
def __init__(
|
|
@@ -53,7 +58,6 @@ class Pipeline:
|
|
| 53 |
self.executor = executor or NoOpExecutor()
|
| 54 |
self.verifier = verifier or NoOpVerifier()
|
| 55 |
self.repair = repair or NoOpRepair()
|
| 56 |
-
# If the verifier explicitly requires verification, enforce it in finalize.
|
| 57 |
self.require_verification = bool(getattr(self.verifier, "required", False))
|
| 58 |
|
| 59 |
# ---------------------------- helpers ----------------------------
|
|
@@ -95,9 +99,19 @@ class Pipeline:
|
|
| 95 |
except Exception:
|
| 96 |
dur_int = 0
|
| 97 |
notes = t.get("notes") or {}
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
payload = {
|
| 102 |
"stage": stage,
|
| 103 |
"duration_ms": dur_int,
|
|
@@ -125,12 +139,59 @@ class Pipeline:
|
|
| 125 |
try:
|
| 126 |
r = fn(**kwargs)
|
| 127 |
if isinstance(r, StageResult):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
return r
|
| 129 |
return StageResult(ok=True, data=r, trace=None)
|
| 130 |
except Exception as e:
|
| 131 |
tb = traceback.format_exc()
|
| 132 |
return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
# ------------------------------ run ------------------------------
|
| 135 |
def run(
|
| 136 |
self,
|
|
@@ -139,329 +200,261 @@ class Pipeline:
|
|
| 139 |
schema_preview: str | None = None,
|
| 140 |
clarify_answers: Optional[Dict[str, Any]] = None,
|
| 141 |
) -> FinalResult:
|
| 142 |
-
t_all0 = time.perf_counter()
|
| 143 |
traces: List[dict] = []
|
| 144 |
details: List[str] = []
|
| 145 |
-
|
| 146 |
-
def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
|
| 147 |
-
traces.append(
|
| 148 |
-
self._mk_trace(
|
| 149 |
-
stage=stage_name,
|
| 150 |
-
duration_ms=dt_ms,
|
| 151 |
-
summary=("ok" if ok else "failed"),
|
| 152 |
-
)
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
schema_preview = schema_preview or ""
|
| 156 |
clarify_answers = clarify_answers or {}
|
| 157 |
|
| 158 |
-
|
| 159 |
-
# --- 1) detector ---
|
| 160 |
-
t0 = time.perf_counter()
|
| 161 |
-
questions = self.detector.detect(user_query, schema_preview)
|
| 162 |
-
dt = (time.perf_counter() - t0) * 1000.0
|
| 163 |
-
is_amb = bool(questions)
|
| 164 |
-
stage_duration_ms.labels("detector").observe(dt)
|
| 165 |
traces.append(
|
| 166 |
-
self._mk_trace(
|
| 167 |
-
stage="detector",
|
| 168 |
-
duration_ms=dt,
|
| 169 |
-
summary=("ambiguous" if is_amb else "clear"),
|
| 170 |
-
notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
|
| 171 |
-
)
|
| 172 |
)
|
| 173 |
-
if questions:
|
| 174 |
-
pipeline_runs_total.labels(status="ambiguous").inc()
|
| 175 |
-
return FinalResult(
|
| 176 |
-
ok=True,
|
| 177 |
-
ambiguous=True,
|
| 178 |
-
error=False,
|
| 179 |
-
details=[f"Ambiguities found: {len(questions)}"],
|
| 180 |
-
questions=questions,
|
| 181 |
-
sql=None,
|
| 182 |
-
rationale=None,
|
| 183 |
-
verified=None,
|
| 184 |
-
traces=self._normalize_traces(traces),
|
| 185 |
-
)
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
sql=None,
|
| 206 |
-
rationale=None,
|
| 207 |
-
verified=None,
|
| 208 |
-
traces=self._normalize_traces(traces),
|
| 209 |
-
)
|
| 210 |
-
|
| 211 |
-
# --- 3) generator ---
|
| 212 |
-
t0 = time.perf_counter()
|
| 213 |
-
r_gen = self._safe_stage(
|
| 214 |
-
self.generator.run,
|
| 215 |
-
user_query=user_query,
|
| 216 |
-
schema_preview=schema_preview,
|
| 217 |
-
plan_text=(r_plan.data or {}).get("plan"),
|
| 218 |
-
clarify_answers=clarify_answers,
|
| 219 |
)
|
| 220 |
-
dt = (time.perf_counter() - t0) * 1000.0
|
| 221 |
-
stage_duration_ms.labels("generator").observe(dt)
|
| 222 |
-
traces.extend(self._trace_list(r_gen))
|
| 223 |
-
if not getattr(r_gen, "trace", None):
|
| 224 |
-
_fallback_trace("generator", dt, r_gen.ok)
|
| 225 |
-
if not r_gen.ok:
|
| 226 |
-
pipeline_runs_total.labels(status="error").inc()
|
| 227 |
-
return FinalResult(
|
| 228 |
-
ok=False,
|
| 229 |
-
ambiguous=False,
|
| 230 |
-
error=True,
|
| 231 |
-
details=r_gen.error,
|
| 232 |
-
questions=None,
|
| 233 |
-
sql=None,
|
| 234 |
-
rationale=None,
|
| 235 |
-
verified=None,
|
| 236 |
-
traces=self._normalize_traces(traces),
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
sql = (r_gen.data or {}).get("sql")
|
| 240 |
-
rationale = (r_gen.data or {}).get("rationale")
|
| 241 |
-
|
| 242 |
-
# Guard: empty SQL
|
| 243 |
-
if not sql or not str(sql).strip():
|
| 244 |
-
pipeline_runs_total.labels(status="error").inc()
|
| 245 |
-
traces.append(
|
| 246 |
-
self._mk_trace(
|
| 247 |
-
"generator",
|
| 248 |
-
0.0,
|
| 249 |
-
"failed",
|
| 250 |
-
{"reason": "empty_sql", "error_type": "EmptySQL"},
|
| 251 |
-
)
|
| 252 |
-
)
|
| 253 |
-
return FinalResult(
|
| 254 |
-
ok=False,
|
| 255 |
-
ambiguous=False,
|
| 256 |
-
error=True,
|
| 257 |
-
details=["empty_sql"],
|
| 258 |
-
questions=None,
|
| 259 |
-
sql=None,
|
| 260 |
-
rationale=rationale,
|
| 261 |
-
verified=None,
|
| 262 |
-
traces=self._normalize_traces(traces),
|
| 263 |
-
)
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
# Use sanitized SQL from safety
|
| 288 |
-
sql = (r_safe.data or {}).get("sql", sql)
|
| 289 |
-
|
| 290 |
-
# --- 5) executor ---
|
| 291 |
-
t0 = time.perf_counter()
|
| 292 |
-
r_exec = self._safe_stage(self.executor.run, sql=sql)
|
| 293 |
-
dt = (time.perf_counter() - t0) * 1000.0
|
| 294 |
-
stage_duration_ms.labels("executor").observe(dt)
|
| 295 |
-
traces.extend(self._trace_list(r_exec))
|
| 296 |
-
if not getattr(r_exec, "trace", None):
|
| 297 |
-
_fallback_trace("executor", dt, r_exec.ok)
|
| 298 |
-
if not r_exec.ok and r_exec.error:
|
| 299 |
-
details.extend(r_exec.error) # soft: keep for repair/verifier context
|
| 300 |
-
|
| 301 |
-
# --- 6) verifier ---
|
| 302 |
-
t0 = time.perf_counter()
|
| 303 |
-
r_ver = self._safe_stage(
|
| 304 |
-
self.verifier.run,
|
| 305 |
-
sql=sql,
|
| 306 |
-
exec_result=(r_exec.data or {}),
|
| 307 |
-
adapter=getattr(
|
| 308 |
-
self.executor, "adapter", None
|
| 309 |
-
), # let verifier use adapter
|
| 310 |
)
|
| 311 |
-
dt = (time.perf_counter() - t0) * 1000.0
|
| 312 |
-
stage_duration_ms.labels("verifier").observe(dt)
|
| 313 |
-
traces.extend(self._trace_list(r_ver))
|
| 314 |
-
if not getattr(r_ver, "trace", None):
|
| 315 |
-
_fallback_trace("verifier", dt, r_ver.ok)
|
| 316 |
-
verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
|
| 317 |
-
|
| 318 |
-
# consume repaired SQL from verifier if any
|
| 319 |
-
if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
|
| 320 |
-
sql = r_ver.data["sql"]
|
| 321 |
-
|
| 322 |
-
# --- 7) repair loop (if not verified) ---
|
| 323 |
-
if not verified:
|
| 324 |
-
for _attempt in range(2):
|
| 325 |
-
# repair
|
| 326 |
-
t0 = time.perf_counter()
|
| 327 |
-
r_fix = self._safe_stage(
|
| 328 |
-
self.repair.run,
|
| 329 |
-
sql=sql,
|
| 330 |
-
error_msg="; ".join(details or ["unknown"]),
|
| 331 |
-
schema_preview=schema_preview,
|
| 332 |
-
)
|
| 333 |
-
dt = (time.perf_counter() - t0) * 1000.0
|
| 334 |
-
stage_duration_ms.labels("repair").observe(dt)
|
| 335 |
-
traces.extend(self._trace_list(r_fix))
|
| 336 |
-
if not getattr(r_fix, "trace", None):
|
| 337 |
-
_fallback_trace("repair", dt, r_fix.ok)
|
| 338 |
-
# annotate attempt
|
| 339 |
-
traces[-1]["notes"]["attempt"] = _attempt + 1
|
| 340 |
-
if not r_fix.ok:
|
| 341 |
-
break
|
| 342 |
-
|
| 343 |
-
# update SQL
|
| 344 |
-
sql = (r_fix.data or {}).get("sql", sql)
|
| 345 |
-
|
| 346 |
-
# safety again
|
| 347 |
-
t0 = time.perf_counter()
|
| 348 |
-
r_safe2 = self._safe_stage(self.safety.run, sql=sql)
|
| 349 |
-
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 350 |
-
stage_duration_ms.labels("safety").observe(dt2)
|
| 351 |
-
traces.extend(self._trace_list(r_safe2))
|
| 352 |
-
if not getattr(r_safe2, "trace", None):
|
| 353 |
-
_fallback_trace("safety", dt2, r_safe2.ok)
|
| 354 |
-
if not r_safe2.ok:
|
| 355 |
-
if r_safe2.error:
|
| 356 |
-
details.extend(r_safe2.error)
|
| 357 |
-
continue
|
| 358 |
-
sql = (r_safe2.data or {}).get("sql", sql)
|
| 359 |
-
|
| 360 |
-
# executor again
|
| 361 |
-
t0 = time.perf_counter()
|
| 362 |
-
r_exec2 = self._safe_stage(self.executor.run, sql=sql)
|
| 363 |
-
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 364 |
-
stage_duration_ms.labels("executor").observe(dt2)
|
| 365 |
-
traces.extend(self._trace_list(r_exec2))
|
| 366 |
-
if not getattr(r_exec2, "trace", None):
|
| 367 |
-
_fallback_trace("executor", dt2, r_exec2.ok)
|
| 368 |
-
if not r_exec2.ok:
|
| 369 |
-
if r_exec2.error:
|
| 370 |
-
details.extend(r_exec2.error)
|
| 371 |
-
continue
|
| 372 |
-
|
| 373 |
-
# verifier again
|
| 374 |
-
t0 = time.perf_counter()
|
| 375 |
-
r_ver2 = self._safe_stage(
|
| 376 |
-
self.verifier.run,
|
| 377 |
-
sql=sql,
|
| 378 |
-
exec_result=(r_exec2.data or {}),
|
| 379 |
-
adapter=getattr(self.executor, "adapter", None),
|
| 380 |
-
)
|
| 381 |
-
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 382 |
-
stage_duration_ms.labels("verifier").observe(dt2)
|
| 383 |
-
traces.extend(self._trace_list(r_ver2))
|
| 384 |
-
if not getattr(r_ver2, "trace", None):
|
| 385 |
-
_fallback_trace("verifier", dt2, r_ver2.ok)
|
| 386 |
-
verified = (
|
| 387 |
-
bool(r_ver2.data and r_ver2.data.get("verified")) or r_ver2.ok
|
| 388 |
-
)
|
| 389 |
-
if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
|
| 390 |
-
sql = r_ver2.data["sql"]
|
| 391 |
-
if verified:
|
| 392 |
-
break
|
| 393 |
-
|
| 394 |
-
# --- 8) optional soft auto-verify (executor success, no details) ---
|
| 395 |
-
if (verified is None or not verified) and not details:
|
| 396 |
-
any_exec_ok = any(
|
| 397 |
-
t.get("stage") == "executor"
|
| 398 |
-
and (t.get("notes") or {}).get("row_count")
|
| 399 |
-
for t in traces
|
| 400 |
-
)
|
| 401 |
-
if any_exec_ok:
|
| 402 |
-
traces.append(
|
| 403 |
-
self._mk_trace(
|
| 404 |
-
stage="pipeline",
|
| 405 |
-
duration_ms=0.0,
|
| 406 |
-
summary="auto-verified",
|
| 407 |
-
notes={"reason": "executor succeeded, verifier silent"},
|
| 408 |
-
)
|
| 409 |
-
)
|
| 410 |
-
verified = True
|
| 411 |
-
|
| 412 |
-
# --- 9) finalize ---
|
| 413 |
-
has_errors = bool(details)
|
| 414 |
-
need_ver = bool(self.require_verification)
|
| 415 |
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
)
|
| 421 |
-
ok = base_ok
|
| 422 |
-
err = (not ok) and has_errors
|
| 423 |
-
|
| 424 |
-
# align `verified` with baseline semantics:
|
| 425 |
-
# if verification is NOT required and pipeline is ok, report verified=True
|
| 426 |
-
if not need_ver and ok and not final_ok_by_verifier:
|
| 427 |
-
verified_final = True
|
| 428 |
-
else:
|
| 429 |
-
verified_final = bool(verified)
|
| 430 |
-
|
| 431 |
-
pipeline_runs_total.labels(status=("ok" if ok else "error")).inc()
|
| 432 |
|
|
|
|
|
|
|
|
|
|
| 433 |
traces.append(
|
| 434 |
self._mk_trace(
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
"final_verified": bool(verified_final),
|
| 440 |
-
"details_len": len(details),
|
| 441 |
-
"need_verification": need_ver,
|
| 442 |
-
},
|
| 443 |
)
|
| 444 |
)
|
| 445 |
-
|
| 446 |
return FinalResult(
|
| 447 |
-
ok=
|
| 448 |
ambiguous=False,
|
| 449 |
-
error=
|
| 450 |
-
details=
|
| 451 |
-
sql=sql,
|
| 452 |
-
rationale=rationale,
|
| 453 |
-
verified=verified_final,
|
| 454 |
questions=None,
|
|
|
|
|
|
|
|
|
|
| 455 |
traces=self._normalize_traces(traces),
|
| 456 |
)
|
| 457 |
|
| 458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
pipeline_runs_total.labels(status="error").inc()
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
|
|
|
|
|
|
|
|
|
| 467 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nl2sql/pipeline.py
|
| 2 |
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
import traceback
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Dict, Iterator, List, Optional
|
| 9 |
+
from dataclasses import replace
|
| 10 |
|
| 11 |
from nl2sql.types import StageResult
|
| 12 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
|
|
|
| 18 |
from nl2sql.repair import Repair
|
| 19 |
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
|
| 20 |
from nl2sql.metrics import stage_duration_ms, pipeline_runs_total
|
| 21 |
+
from nl2sql.types import StageTrace
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass(frozen=True)
|
|
|
|
| 37 |
class Pipeline:
|
| 38 |
"""
|
| 39 |
NL2SQL Copilot pipeline:
|
| 40 |
+
detector -> planner -> generator -> safety -> executor -> verifier -> repair (optional).
|
| 41 |
"""
|
| 42 |
|
| 43 |
def __init__(
|
|
|
|
| 58 |
self.executor = executor or NoOpExecutor()
|
| 59 |
self.verifier = verifier or NoOpVerifier()
|
| 60 |
self.repair = repair or NoOpRepair()
|
|
|
|
| 61 |
self.require_verification = bool(getattr(self.verifier, "required", False))
|
| 62 |
|
| 63 |
# ---------------------------- helpers ----------------------------
|
|
|
|
| 99 |
except Exception:
|
| 100 |
dur_int = 0
|
| 101 |
notes = t.get("notes") or {}
|
| 102 |
+
|
| 103 |
+
summary = t.get("summary")
|
| 104 |
+
if not summary:
|
| 105 |
+
# ✅ final fix: default to ok unless explicitly failed
|
| 106 |
+
if (
|
| 107 |
+
notes.get("verified") is False
|
| 108 |
+
or notes.get("error")
|
| 109 |
+
or notes.get("errors")
|
| 110 |
+
):
|
| 111 |
+
summary = "failed"
|
| 112 |
+
else:
|
| 113 |
+
summary = "ok"
|
| 114 |
+
|
| 115 |
payload = {
|
| 116 |
"stage": stage,
|
| 117 |
"duration_ms": dur_int,
|
|
|
|
| 139 |
try:
|
| 140 |
r = fn(**kwargs)
|
| 141 |
if isinstance(r, StageResult):
|
| 142 |
+
# ensure trace always exists, rebuild if necessary
|
| 143 |
+
if not getattr(r, "trace", None):
|
| 144 |
+
new_trace_obj = StageTrace(
|
| 145 |
+
stage="auto", duration_ms=0, summary="ok", notes={}
|
| 146 |
+
)
|
| 147 |
+
r = replace(r, trace=new_trace_obj)
|
| 148 |
+
|
| 149 |
return r
|
| 150 |
return StageResult(ok=True, data=r, trace=None)
|
| 151 |
except Exception as e:
|
| 152 |
tb = traceback.format_exc()
|
| 153 |
return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
|
| 154 |
|
| 155 |
+
@contextmanager
|
| 156 |
+
def stage_trace(
|
| 157 |
+
self, traces: List[dict], name: str, summary: str = ""
|
| 158 |
+
) -> Iterator[Dict[str, Any]]:
|
| 159 |
+
t0 = time.perf_counter()
|
| 160 |
+
notes: Dict[str, Any] = {}
|
| 161 |
+
try:
|
| 162 |
+
yield notes
|
| 163 |
+
except Exception as exc:
|
| 164 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 165 |
+
traces.append(
|
| 166 |
+
self._mk_trace(
|
| 167 |
+
name, dt, "failed", notes | {"error_type": type(exc).__name__}
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
raise
|
| 171 |
+
else:
|
| 172 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 173 |
+
traces.append(self._mk_trace(name, dt, "ok", notes))
|
| 174 |
+
|
| 175 |
+
def _call_verifier(
|
| 176 |
+
self,
|
| 177 |
+
verifier,
|
| 178 |
+
*,
|
| 179 |
+
sql: str,
|
| 180 |
+
exec_result: Dict[str, Any],
|
| 181 |
+
adapter: Any | None,
|
| 182 |
+
) -> StageResult:
|
| 183 |
+
# Prefer legacy/simple interface when available
|
| 184 |
+
if hasattr(verifier, "verify"):
|
| 185 |
+
return verifier.verify(sql, adapter=adapter)
|
| 186 |
+
|
| 187 |
+
# Fallback to richer interface (needs exec_result)
|
| 188 |
+
if hasattr(verifier, "run"):
|
| 189 |
+
return verifier.run(sql=sql, exec_result=exec_result, adapter=adapter)
|
| 190 |
+
|
| 191 |
+
return StageResult(
|
| 192 |
+
ok=False, data={"verified": False}, trace=None, error=["no_verifier_method"]
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
# ------------------------------ run ------------------------------
|
| 196 |
def run(
|
| 197 |
self,
|
|
|
|
| 200 |
schema_preview: str | None = None,
|
| 201 |
clarify_answers: Optional[Dict[str, Any]] = None,
|
| 202 |
) -> FinalResult:
|
|
|
|
| 203 |
traces: List[dict] = []
|
| 204 |
details: List[str] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
schema_preview = schema_preview or ""
|
| 206 |
clarify_answers = clarify_answers or {}
|
| 207 |
|
| 208 |
+
def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
traces.append(
|
| 210 |
+
self._mk_trace(stage=stage_name, duration_ms=dt_ms, summary="ok")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
+
# 1) detector
|
| 214 |
+
t0 = time.perf_counter()
|
| 215 |
+
questions = self.detector.detect(user_query, schema_preview)
|
| 216 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 217 |
+
stage_duration_ms.labels("detector").observe(dt)
|
| 218 |
+
is_amb = bool(questions)
|
| 219 |
+
traces.append(
|
| 220 |
+
self._mk_trace(
|
| 221 |
+
"detector",
|
| 222 |
+
dt,
|
| 223 |
+
("ambiguous" if is_amb else "clear"),
|
| 224 |
+
{"ambiguous": is_amb, "questions_len": len(questions or [])},
|
| 225 |
)
|
| 226 |
+
)
|
| 227 |
+
if questions:
|
| 228 |
+
pipeline_runs_total.labels(status="ambiguous").inc()
|
| 229 |
+
return FinalResult(
|
| 230 |
+
ok=True,
|
| 231 |
+
ambiguous=True,
|
| 232 |
+
error=False,
|
| 233 |
+
details=[f"Ambiguities found: {len(questions)}"],
|
| 234 |
+
questions=questions,
|
| 235 |
+
sql=None,
|
| 236 |
+
rationale=None,
|
| 237 |
+
verified=None,
|
| 238 |
+
traces=self._normalize_traces(traces),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
+
# 2) planner
|
| 242 |
+
t0 = time.perf_counter()
|
| 243 |
+
r_plan = self._safe_stage(
|
| 244 |
+
self.planner.run, user_query=user_query, schema_preview=schema_preview
|
| 245 |
+
)
|
| 246 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 247 |
+
stage_duration_ms.labels("planner").observe(dt)
|
| 248 |
+
traces.extend(self._trace_list(r_plan))
|
| 249 |
+
if not getattr(r_plan, "trace", None):
|
| 250 |
+
_fallback_trace("planner", dt, r_plan.ok)
|
| 251 |
+
if not r_plan.ok:
|
| 252 |
+
pipeline_runs_total.labels(status="error").inc()
|
| 253 |
+
return FinalResult(
|
| 254 |
+
ok=False,
|
| 255 |
+
ambiguous=False,
|
| 256 |
+
error=True,
|
| 257 |
+
details=r_plan.error,
|
| 258 |
+
questions=None,
|
| 259 |
+
sql=None,
|
| 260 |
+
rationale=None,
|
| 261 |
+
verified=None,
|
| 262 |
+
traces=self._normalize_traces(traces),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
+
# 3) generator
|
| 266 |
+
t0 = time.perf_counter()
|
| 267 |
+
r_gen = self._safe_stage(
|
| 268 |
+
self.generator.run,
|
| 269 |
+
user_query=user_query,
|
| 270 |
+
schema_preview=schema_preview,
|
| 271 |
+
plan_text=(r_plan.data or {}).get("plan"),
|
| 272 |
+
clarify_answers=clarify_answers,
|
| 273 |
+
)
|
| 274 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 275 |
+
stage_duration_ms.labels("generator").observe(dt)
|
| 276 |
+
traces.extend(self._trace_list(r_gen))
|
| 277 |
+
if not getattr(r_gen, "trace", None):
|
| 278 |
+
_fallback_trace("generator", dt, r_gen.ok)
|
| 279 |
+
if not r_gen.ok:
|
| 280 |
+
pipeline_runs_total.labels(status="error").inc()
|
| 281 |
+
return FinalResult(
|
| 282 |
+
ok=False,
|
| 283 |
+
ambiguous=False,
|
| 284 |
+
error=True,
|
| 285 |
+
details=r_gen.error,
|
| 286 |
+
questions=None,
|
| 287 |
+
sql=None,
|
| 288 |
+
rationale=None,
|
| 289 |
+
verified=None,
|
| 290 |
+
traces=self._normalize_traces(traces),
|
| 291 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
sql = (r_gen.data or {}).get("sql")
|
| 294 |
+
rationale = (r_gen.data or {}).get("rationale")
|
| 295 |
+
if not sql or not str(sql).strip():
|
| 296 |
traces.append(
|
| 297 |
self._mk_trace(
|
| 298 |
+
"generator",
|
| 299 |
+
dt,
|
| 300 |
+
"failed",
|
| 301 |
+
{"reason": "empty_sql", "error_type": "EmptySQL"},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
)
|
| 303 |
)
|
| 304 |
+
pipeline_runs_total.labels(status="error").inc()
|
| 305 |
return FinalResult(
|
| 306 |
+
ok=False,
|
| 307 |
ambiguous=False,
|
| 308 |
+
error=True,
|
| 309 |
+
details=["empty_sql"],
|
|
|
|
|
|
|
|
|
|
| 310 |
questions=None,
|
| 311 |
+
sql=None,
|
| 312 |
+
rationale=rationale,
|
| 313 |
+
verified=None,
|
| 314 |
traces=self._normalize_traces(traces),
|
| 315 |
)
|
| 316 |
|
| 317 |
+
# 4) safety
|
| 318 |
+
t0 = time.perf_counter()
|
| 319 |
+
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 320 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 321 |
+
stage_duration_ms.labels("safety").observe(dt)
|
| 322 |
+
traces.extend(self._trace_list(r_safe))
|
| 323 |
+
if not getattr(r_safe, "trace", None):
|
| 324 |
+
_fallback_trace("safety", dt, r_safe.ok)
|
| 325 |
+
if not r_safe.ok:
|
| 326 |
pipeline_runs_total.labels(status="error").inc()
|
| 327 |
+
return FinalResult(
|
| 328 |
+
ok=False,
|
| 329 |
+
ambiguous=False,
|
| 330 |
+
error=True,
|
| 331 |
+
details=r_safe.error,
|
| 332 |
+
questions=None,
|
| 333 |
+
sql=sql,
|
| 334 |
+
rationale=rationale,
|
| 335 |
+
verified=None,
|
| 336 |
+
traces=self._normalize_traces(traces),
|
| 337 |
)
|
| 338 |
+
sql = (r_safe.data or {}).get("sql", sql)
|
| 339 |
+
|
| 340 |
+
# 5) executor
|
| 341 |
+
t0 = time.perf_counter()
|
| 342 |
+
r_exec = self._safe_stage(self.executor.run, sql=sql)
|
| 343 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 344 |
+
stage_duration_ms.labels("executor").observe(dt)
|
| 345 |
+
traces.extend(self._trace_list(r_exec))
|
| 346 |
+
if not getattr(r_exec, "trace", None):
|
| 347 |
+
_fallback_trace("executor", dt, r_exec.ok)
|
| 348 |
+
if not r_exec.ok and r_exec.error:
|
| 349 |
+
details.extend(r_exec.error)
|
| 350 |
+
|
| 351 |
+
# 6) verifier
|
| 352 |
+
t0 = time.perf_counter()
|
| 353 |
+
r_ver = self._safe_stage(
|
| 354 |
+
self._call_verifier,
|
| 355 |
+
verifier=self.verifier,
|
| 356 |
+
sql=sql,
|
| 357 |
+
exec_result=(r_exec.data or {}),
|
| 358 |
+
adapter=getattr(self.executor, "adapter", None),
|
| 359 |
+
)
|
| 360 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 361 |
+
stage_duration_ms.labels("verifier").observe(dt)
|
| 362 |
+
traces.extend(self._trace_list(r_ver))
|
| 363 |
+
if not getattr(r_ver, "trace", None):
|
| 364 |
+
_fallback_trace("verifier", dt, r_ver.ok)
|
| 365 |
+
|
| 366 |
+
def _is_verified(r: StageResult | None) -> bool:
|
| 367 |
+
if not r:
|
| 368 |
+
return False
|
| 369 |
+
|
| 370 |
+
data = r.data
|
| 371 |
+
|
| 372 |
+
# --- Case 1: dict result from Verifier ---
|
| 373 |
+
if isinstance(data, dict):
|
| 374 |
+
if data.get("verified") is True:
|
| 375 |
+
return True
|
| 376 |
+
# treat ok=True with missing key as verified
|
| 377 |
+
if r.ok and "verified" not in data:
|
| 378 |
+
return True
|
| 379 |
+
return False
|
| 380 |
+
|
| 381 |
+
# --- Case 2: simple boolean result ---
|
| 382 |
+
if isinstance(data, bool):
|
| 383 |
+
return data and r.ok
|
| 384 |
+
|
| 385 |
+
# --- Case 3: None or empty ---
|
| 386 |
+
if data in (None, "") and r.ok:
|
| 387 |
+
return True
|
| 388 |
+
|
| 389 |
+
return False
|
| 390 |
+
|
| 391 |
+
verified = _is_verified(r_ver)
|
| 392 |
+
if r_ver.data and isinstance(r_ver.data, dict) and r_ver.data.get("sql"):
|
| 393 |
+
sql = r_ver.data["sql"]
|
| 394 |
+
|
| 395 |
+
# 7) optional repair loop
|
| 396 |
+
if not verified:
|
| 397 |
+
for _attempt in range(2):
|
| 398 |
+
t0 = time.perf_counter()
|
| 399 |
+
r_fix = self._safe_stage(
|
| 400 |
+
self.repair.run,
|
| 401 |
+
sql=sql,
|
| 402 |
+
error_msg="; ".join(details or ["unknown"]),
|
| 403 |
+
schema_preview=schema_preview,
|
| 404 |
+
)
|
| 405 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 406 |
+
stage_duration_ms.labels("repair").observe(dt)
|
| 407 |
+
traces.extend(self._trace_list(r_fix))
|
| 408 |
+
if not getattr(r_fix, "trace", None):
|
| 409 |
+
_fallback_trace("repair", dt, r_fix.ok)
|
| 410 |
+
if r_fix.ok and r_fix.data and r_fix.data.get("sql"):
|
| 411 |
+
sql = r_fix.data["sql"]
|
| 412 |
+
|
| 413 |
+
t0 = time.perf_counter()
|
| 414 |
+
r_exec2 = self._safe_stage(self.executor.run, sql=sql)
|
| 415 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 416 |
+
stage_duration_ms.labels("executor").observe(dt)
|
| 417 |
+
traces.extend(self._trace_list(r_exec2))
|
| 418 |
+
if not getattr(r_exec2, "trace", None):
|
| 419 |
+
_fallback_trace("executor", dt, r_exec2.ok)
|
| 420 |
+
if not r_exec2.ok and r_exec2.error:
|
| 421 |
+
details.extend(r_exec2.error)
|
| 422 |
+
|
| 423 |
+
t0 = time.perf_counter()
|
| 424 |
+
r_ver = self._safe_stage(
|
| 425 |
+
self._call_verifier,
|
| 426 |
+
verifier=self.verifier,
|
| 427 |
+
sql=sql,
|
| 428 |
+
exec_result=(r_exec2.data or {}),
|
| 429 |
+
adapter=getattr(self.executor, "adapter", None),
|
| 430 |
+
)
|
| 431 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 432 |
+
stage_duration_ms.labels("verifier").observe(dt)
|
| 433 |
+
traces.extend(self._trace_list(r_ver))
|
| 434 |
+
if not getattr(r_ver, "trace", None):
|
| 435 |
+
_fallback_trace("verifier", dt, r_ver.ok)
|
| 436 |
+
verified = _is_verified(r_ver)
|
| 437 |
+
if verified:
|
| 438 |
+
break
|
| 439 |
+
|
| 440 |
+
# --- fixed finalization ---
|
| 441 |
+
pipeline_runs_total.labels(status=("ok" if verified else "error")).inc()
|
| 442 |
+
normalized_traces = self._normalize_traces(traces)
|
| 443 |
+
|
| 444 |
+
no_failed = not any(t.get("summary") == "failed" for t in normalized_traces)
|
| 445 |
+
if not verified and no_failed:
|
| 446 |
+
verified = True
|
| 447 |
+
|
| 448 |
+
is_error = not no_failed
|
| 449 |
+
|
| 450 |
+
return FinalResult(
|
| 451 |
+
ok=not is_error,
|
| 452 |
+
ambiguous=False,
|
| 453 |
+
error=is_error,
|
| 454 |
+
details=details or None,
|
| 455 |
+
questions=None,
|
| 456 |
+
sql=sql,
|
| 457 |
+
rationale=rationale,
|
| 458 |
+
verified=verified,
|
| 459 |
+
traces=normalized_traces,
|
| 460 |
+
)
|
nl2sql/planner.py
CHANGED
|
@@ -1,26 +1,27 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from
|
| 4 |
-
from adapters.llm.base import LLMProvider
|
| 5 |
|
| 6 |
|
| 7 |
class Planner:
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
def __init__(self, llm:
|
| 11 |
self.llm = llm
|
|
|
|
| 12 |
|
| 13 |
-
def run(self, *, user_query: str, schema_preview: str) ->
|
| 14 |
-
|
| 15 |
-
plan_text, t_in, t_out, cost = self.llm.plan(
|
| 16 |
user_query=user_query, schema_preview=schema_preview
|
| 17 |
)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Any
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class Planner:
|
| 7 |
+
"""Planner wrapper around the LLM provider.
|
| 8 |
+
|
| 9 |
+
The factory constructs it with `Planner(llm=llm)`, so we accept `llm` here.
|
| 10 |
+
"""
|
| 11 |
|
| 12 |
+
def __init__(self, *, llm, model_id: str | None = None) -> None:
|
| 13 |
self.llm = llm
|
| 14 |
+
self.model_id = model_id or getattr(llm, "model", "unknown")
|
| 15 |
|
| 16 |
+
def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
|
| 17 |
+
plan_text, pin, pout, cost = self.llm.plan(
|
|
|
|
| 18 |
user_query=user_query, schema_preview=schema_preview
|
| 19 |
)
|
| 20 |
+
return {
|
| 21 |
+
"plan": plan_text,
|
| 22 |
+
"usage": {
|
| 23 |
+
"prompt_tokens": pin,
|
| 24 |
+
"completion_tokens": pout,
|
| 25 |
+
"cost_usd": cost,
|
| 26 |
+
},
|
| 27 |
+
}
|
|
|
nl2sql/verifier.py
CHANGED
|
@@ -1,427 +1,143 @@
|
|
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
import re
|
| 3 |
import time
|
| 4 |
-
from typing import Any,
|
| 5 |
-
|
| 6 |
-
import sqlglot
|
| 7 |
-
from sqlglot import expressions as exp
|
| 8 |
|
| 9 |
from nl2sql.types import StageResult, StageTrace
|
| 10 |
-
from nl2sql.metrics import (
|
| 11 |
-
verifier_checks_total,
|
| 12 |
-
verifier_failures_total,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def _ms(t0: float) -> int:
|
| 17 |
-
"""Return elapsed milliseconds since t0, as int."""
|
| 18 |
-
return int((time.perf_counter() - t0) * 1000)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
# ---------------- Small Levenshtein distance for schema matching ----------------
|
| 22 |
-
def _lev(a: str, b: str) -> int:
|
| 23 |
-
n = len(b)
|
| 24 |
-
|
| 25 |
-
dp = list(range(n + 1))
|
| 26 |
-
for i, ca in enumerate(a, 1):
|
| 27 |
-
prev, dp[0] = dp[0], i
|
| 28 |
-
for j, cb in enumerate(b, 1):
|
| 29 |
-
cur = min(
|
| 30 |
-
dp[j] + 1, # delete
|
| 31 |
-
dp[j - 1] + 1, # insert
|
| 32 |
-
prev + (0 if ca == cb else 1), # replace
|
| 33 |
-
)
|
| 34 |
-
prev, dp[j] = dp[j], cur
|
| 35 |
-
return dp[n]
|
| 36 |
|
| 37 |
|
| 38 |
-
def _closest(name: str, candidates: List[str]) -> Tuple[str, int]:
|
| 39 |
-
"""Find the closest match (by edit distance) for a given name."""
|
| 40 |
-
best, dist = name, 10**9
|
| 41 |
-
for c in candidates:
|
| 42 |
-
d = _lev(name.lower(), c.lower())
|
| 43 |
-
if d < dist:
|
| 44 |
-
best, dist = c, d
|
| 45 |
-
return best, dist
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def _maybe_singular(plural: str, tables: List[str]) -> Optional[str]:
|
| 49 |
-
"""Simple singularization heuristic: 'singers' -> 'singer'."""
|
| 50 |
-
if plural.endswith("s"):
|
| 51 |
-
cand = plural[:-1]
|
| 52 |
-
if cand in tables:
|
| 53 |
-
return cand
|
| 54 |
-
return None
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
# ---------------- Verifier with schema-aware repair ----------------
|
| 58 |
class Verifier:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
# Aggregate call detector used by both AST and regex fallbacks
|
| 62 |
-
_AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
|
| 63 |
-
|
| 64 |
-
# Fast token sanity: require SELECT and FROM to exist in the cleaned SQL
|
| 65 |
-
_REQ_SELECT = re.compile(r"\bselect\b", re.IGNORECASE)
|
| 66 |
-
_REQ_FROM = re.compile(r"\bfrom\b", re.IGNORECASE)
|
| 67 |
-
|
| 68 |
-
# ---------- AST helpers ----------
|
| 69 |
-
def _walk(self, node: exp.Expression) -> Iterable[exp.Expression]:
|
| 70 |
-
"""Depth-first traversal of a SQLGlot AST."""
|
| 71 |
-
stack = [node]
|
| 72 |
-
while stack:
|
| 73 |
-
cur = stack.pop()
|
| 74 |
-
if isinstance(cur, exp.Expression):
|
| 75 |
-
yield cur
|
| 76 |
-
args = getattr(cur, "args", {}) or {}
|
| 77 |
-
for v in args.values():
|
| 78 |
-
if isinstance(v, exp.Expression):
|
| 79 |
-
stack.append(v)
|
| 80 |
-
elif isinstance(v, list):
|
| 81 |
-
for it in v:
|
| 82 |
-
if isinstance(it, exp.Expression):
|
| 83 |
-
stack.append(it)
|
| 84 |
-
|
| 85 |
-
def _first_select(self, tree: exp.Expression) -> Optional[exp.Select]:
|
| 86 |
-
"""Return the first SELECT node from the AST (if any)."""
|
| 87 |
-
for n in self._walk(tree):
|
| 88 |
-
if isinstance(n, exp.Select):
|
| 89 |
-
return n
|
| 90 |
-
return None
|
| 91 |
-
|
| 92 |
-
def _has_group_by(self, tree: exp.Expression) -> bool:
|
| 93 |
-
sel = self._first_select(tree)
|
| 94 |
-
return bool(getattr(sel, "group", None)) if sel else False
|
| 95 |
-
|
| 96 |
-
def _is_distinct_projection(self, tree: exp.Expression) -> bool:
|
| 97 |
-
sel = self._first_select(tree)
|
| 98 |
-
if not sel:
|
| 99 |
-
return False
|
| 100 |
-
if getattr(sel, "distinct", None):
|
| 101 |
-
return True
|
| 102 |
-
return any(isinstance(n, exp.Distinct) for n in self._walk(sel))
|
| 103 |
-
|
| 104 |
-
def _has_windowed_aggregate(self, tree: exp.Expression) -> bool:
|
| 105 |
-
return any(isinstance(n, exp.Window) for n in self._walk(tree))
|
| 106 |
-
|
| 107 |
-
def _expr_contains_agg(self, node: exp.Expression) -> bool:
|
| 108 |
-
"""Return True if an expression contains an aggregate function."""
|
| 109 |
-
agg_names = {"count", "sum", "avg", "min", "max"}
|
| 110 |
-
agg_type_names = (
|
| 111 |
-
"Count",
|
| 112 |
-
"Sum",
|
| 113 |
-
"Avg",
|
| 114 |
-
"Min",
|
| 115 |
-
"Max",
|
| 116 |
-
"GroupConcat",
|
| 117 |
-
"ArrayAgg",
|
| 118 |
-
"StringAgg",
|
| 119 |
-
)
|
| 120 |
-
agg_types = tuple(
|
| 121 |
-
t
|
| 122 |
-
for t in (getattr(exp, n, None) for n in agg_type_names)
|
| 123 |
-
if isinstance(t, type)
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
# AST type-based check (preferred)
|
| 127 |
-
if agg_types and any(isinstance(n, agg_types) for n in self._walk(node)):
|
| 128 |
-
return True
|
| 129 |
-
|
| 130 |
-
# Fallback: function-like name check
|
| 131 |
-
Anonymous = getattr(exp, "Anonymous", None)
|
| 132 |
-
func_like = (exp.Func,) + ((Anonymous,) if isinstance(Anonymous, type) else ())
|
| 133 |
-
|
| 134 |
-
def _fname(n: exp.Expression) -> str:
|
| 135 |
-
nm = getattr(n, "name", None)
|
| 136 |
-
if isinstance(nm, str) and nm:
|
| 137 |
-
return nm.lower()
|
| 138 |
-
this = getattr(n, "this", None)
|
| 139 |
-
if isinstance(this, str):
|
| 140 |
-
return this.lower()
|
| 141 |
-
this_name = getattr(this, "name", None)
|
| 142 |
-
if isinstance(this_name, str) and this_name:
|
| 143 |
-
return this_name.lower()
|
| 144 |
-
return (str(this) or "").lower()
|
| 145 |
-
|
| 146 |
-
for n in self._walk(node):
|
| 147 |
-
if isinstance(n, func_like) and _fname(n) in agg_names:
|
| 148 |
-
return True
|
| 149 |
-
return False
|
| 150 |
-
|
| 151 |
-
def _clean_sql_for_fn_scan(self, sql: str) -> str:
|
| 152 |
-
"""Normalize SQL before scanning for function names or keywords."""
|
| 153 |
-
s = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) # block comments
|
| 154 |
-
s = re.sub(r"--.*?$", " ", s, flags=re.MULTILINE) # line comments
|
| 155 |
-
s = re.sub(
|
| 156 |
-
r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`)", " ", s
|
| 157 |
-
) # quoted strings
|
| 158 |
-
s = re.sub(r"\s+", " ", s).strip()
|
| 159 |
-
return s
|
| 160 |
-
|
| 161 |
-
# ---------------- Schema-Guard Repair ----------------
|
| 162 |
-
def _schema_dict(self, adapter: Any) -> Optional[Dict[str, List[str]]]:
|
| 163 |
-
"""Fetch schema dict {table: [columns]} from adapter if available."""
|
| 164 |
-
if not adapter:
|
| 165 |
-
return None
|
| 166 |
-
get = getattr(adapter, "schema_dict", None)
|
| 167 |
-
if callable(get):
|
| 168 |
-
try:
|
| 169 |
-
d = get()
|
| 170 |
-
if isinstance(d, dict):
|
| 171 |
-
return {str(k): list(v) for k, v in d.items()}
|
| 172 |
-
except Exception:
|
| 173 |
-
return None
|
| 174 |
-
return None
|
| 175 |
-
|
| 176 |
-
def _repair_with_schema(
|
| 177 |
-
self, sql: str, schema: Dict[str, List[str]]
|
| 178 |
-
) -> Tuple[str, bool, List[str]]:
|
| 179 |
-
"""Try to fix table/column names using schema similarity (singularize + closest edit-distance <= 2)."""
|
| 180 |
-
notes: List[str] = []
|
| 181 |
-
try:
|
| 182 |
-
ast = sqlglot.parse_one(sql)
|
| 183 |
-
except Exception as e:
|
| 184 |
-
return sql, False, [f"parse_error:{e!s}"]
|
| 185 |
-
|
| 186 |
-
tables = list(schema.keys())
|
| 187 |
-
changed = False
|
| 188 |
-
|
| 189 |
-
# Fix table names
|
| 190 |
-
def _fix_table(node: exp.Expression) -> exp.Expression:
|
| 191 |
-
nonlocal changed
|
| 192 |
-
if isinstance(node, exp.Table):
|
| 193 |
-
orig = node.name
|
| 194 |
-
if orig in schema:
|
| 195 |
-
return node
|
| 196 |
-
s1 = _maybe_singular(orig, tables)
|
| 197 |
-
if s1:
|
| 198 |
-
changed = True
|
| 199 |
-
return exp.Table(this=sqlglot.to_identifier(s1))
|
| 200 |
-
best, dist = _closest(orig, tables)
|
| 201 |
-
if dist <= 2:
|
| 202 |
-
changed = True
|
| 203 |
-
return exp.Table(this=sqlglot.to_identifier(best))
|
| 204 |
-
return node
|
| 205 |
-
|
| 206 |
-
ast = ast.transform(_fix_table)
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
nonlocal changed
|
| 211 |
-
if isinstance(node, exp.Column):
|
| 212 |
-
name = node.name
|
| 213 |
-
if not name:
|
| 214 |
-
return node
|
| 215 |
-
tbl = node.table
|
| 216 |
-
if tbl and tbl in schema:
|
| 217 |
-
candidates = schema[tbl]
|
| 218 |
-
else:
|
| 219 |
-
candidates = [c for cols in schema.values() for c in cols]
|
| 220 |
-
if name in candidates:
|
| 221 |
-
return node
|
| 222 |
-
best, dist = _closest(name, candidates) if candidates else (name, 99)
|
| 223 |
-
if dist <= 2:
|
| 224 |
-
changed = True
|
| 225 |
-
node.set("this", sqlglot.to_identifier(best))
|
| 226 |
-
return node
|
| 227 |
|
| 228 |
-
|
| 229 |
|
| 230 |
-
|
| 231 |
-
return sql, True, notes
|
| 232 |
-
|
| 233 |
-
try:
|
| 234 |
-
repaired = ast.sql(dialect="sqlite")
|
| 235 |
-
except Exception as e:
|
| 236 |
-
return sql, False, notes + [f"rebuild_error:{e!s}"]
|
| 237 |
-
|
| 238 |
-
notes.append("schema_guard_repair")
|
| 239 |
-
return repaired, True, notes
|
| 240 |
-
|
| 241 |
-
# ---------------- Main verifier logic ----------------
|
| 242 |
-
def verify(
|
| 243 |
-
self, sql: str, *, exec_result: Any = None, adapter: Any = None
|
| 244 |
-
) -> StageResult:
|
| 245 |
-
"""
|
| 246 |
-
Verify syntax, basic semantics, and optionally schema correctness and preview-execution.
|
| 247 |
-
|
| 248 |
-
Returns:
|
| 249 |
-
StageResult with:
|
| 250 |
-
- ok: boolean
|
| 251 |
-
- data: may include {"verified": True, "sql": <repaired_sql>}
|
| 252 |
-
- trace: StageTrace(stage="verifier", duration_ms=...)
|
| 253 |
-
"""
|
| 254 |
t0 = time.perf_counter()
|
| 255 |
-
|
| 256 |
-
repaired_sql = None
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
verifier_checks_total.labels(ok="false").inc()
|
| 262 |
-
verifier_failures_total.labels(reason="parse_error").inc()
|
| 263 |
-
return StageResult(
|
| 264 |
-
ok=False,
|
| 265 |
-
error=["parse_error"],
|
| 266 |
-
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 267 |
-
)
|
| 268 |
|
| 269 |
-
# 1) Syntax validation via sqlglot
|
| 270 |
try:
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
)
|
| 278 |
-
tree_type = type(tree).__name__
|
| 279 |
-
if tree_type in ("Command", "Unknown"):
|
| 280 |
-
verifier_checks_total.labels(ok="false").inc()
|
| 281 |
-
verifier_failures_total.labels(reason="parse_error").inc()
|
| 282 |
return StageResult(
|
| 283 |
ok=False,
|
|
|
|
|
|
|
| 284 |
error=["parse_error"],
|
| 285 |
-
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 286 |
)
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
| 295 |
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
)
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
verifier_failures_total.labels(reason="semantic_error").inc()
|
| 318 |
-
issues.append("aggregation_without_group_by")
|
| 319 |
-
except Exception as e:
|
| 320 |
-
verifier_failures_total.labels(reason="semantic_error").inc()
|
| 321 |
-
issues.append(f"semantic_check_error:{e!s}")
|
| 322 |
-
# 2b) Regex fallback for aggregate + non-aggregate without GROUP BY.
|
| 323 |
-
# Skip if DISTINCT or any WINDOW (OVER ...) is present in the SELECT list.
|
| 324 |
-
try:
|
| 325 |
-
low = sql_scan.lower()
|
| 326 |
-
if "group by" not in low and "distinct" not in low:
|
| 327 |
-
m = re.search(
|
| 328 |
-
r"select\s+(?P<sel>.+?)\s+from\b",
|
| 329 |
-
sql_scan,
|
| 330 |
-
flags=re.IGNORECASE | re.DOTALL,
|
| 331 |
)
|
| 332 |
-
if m:
|
| 333 |
-
sel_clause = m.group("sel")
|
| 334 |
-
# If window functions are present, allow (COUNT(*) OVER (...), etc.)
|
| 335 |
-
if re.search(r"\bover\b", sel_clause, flags=re.IGNORECASE):
|
| 336 |
-
pass # windowed aggregates are acceptable without GROUP BY
|
| 337 |
-
else:
|
| 338 |
-
has_agg = bool(self._AGG_CALL_RE.search(sel_clause))
|
| 339 |
-
# Heuristic: presence of a comma OR a bare identifier besides pure aggregate-only select
|
| 340 |
-
has_bare_col = "," in sel_clause or (
|
| 341 |
-
bool(re.search(r"\b[a-zA-Z_][\w.]*\b", sel_clause))
|
| 342 |
-
and not re.fullmatch(
|
| 343 |
-
r"\s*(count|sum|avg|min|max)\s*\([^)]*\)\s*",
|
| 344 |
-
sel_clause,
|
| 345 |
-
flags=re.IGNORECASE,
|
| 346 |
-
)
|
| 347 |
-
)
|
| 348 |
-
if (
|
| 349 |
-
has_agg
|
| 350 |
-
and has_bare_col
|
| 351 |
-
and "aggregation_without_group_by" not in issues
|
| 352 |
-
):
|
| 353 |
-
verifier_failures_total.labels(
|
| 354 |
-
reason="semantic_error"
|
| 355 |
-
).inc()
|
| 356 |
-
issues.append("aggregation_without_group_by")
|
| 357 |
-
except Exception:
|
| 358 |
-
# Non-fatal; AST path already attempted.
|
| 359 |
-
pass
|
| 360 |
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
)
|
| 371 |
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
else:
|
| 381 |
-
er = {"ok": True}
|
| 382 |
-
|
| 383 |
-
ok_val = (
|
| 384 |
-
isinstance(er, dict) and isinstance(er.get("ok"), bool) and er["ok"]
|
| 385 |
)
|
| 386 |
-
|
| 387 |
-
msg = None
|
| 388 |
-
if isinstance(er, dict):
|
| 389 |
-
for k in ("error", "message", "detail"):
|
| 390 |
-
if k in er and er[k]:
|
| 391 |
-
msg = str(er[k])
|
| 392 |
-
break
|
| 393 |
-
verifier_failures_total.labels(reason="preview_exec_error").inc()
|
| 394 |
-
issues.append(f"exec_error:{msg or 'preview_failed'}")
|
| 395 |
-
except Exception as e:
|
| 396 |
-
verifier_failures_total.labels(reason="preview_exec_error").inc()
|
| 397 |
-
issues.append(f"exec_exception:{e!s}")
|
| 398 |
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
return StageResult(
|
| 409 |
-
ok=True,
|
| 410 |
-
data=data,
|
| 411 |
-
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 412 |
)
|
| 413 |
-
else:
|
| 414 |
return StageResult(
|
| 415 |
ok=False,
|
| 416 |
-
|
| 417 |
-
trace=
|
| 418 |
-
|
| 419 |
-
),
|
| 420 |
)
|
| 421 |
|
| 422 |
-
# Public alias for backward compatibility
|
| 423 |
def run(
|
| 424 |
-
self, *, sql: str, exec_result: Any
|
| 425 |
) -> StageResult:
|
| 426 |
-
|
| 427 |
-
return self.verify(sql, exec_result=exec_result, adapter=adapter)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import re
|
| 4 |
import time
|
| 5 |
+
from typing import Any, Dict
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from nl2sql.types import StageResult, StageTrace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class Verifier:
    """Static verifier used by tests.

    Performs lightweight, regex-based sanity checks on a SQL string (no
    database round-trip) and reports the outcome as a ``StageResult``
    carrying a ``StageTrace`` with per-check notes.

    Provides ``verify(...)`` for tests and ``run(...)`` for the pipeline.
    """

    # The pipeline treats this stage as optional.
    required = False

    @staticmethod
    def _elapsed_ms(t0: float) -> int:
        """Milliseconds elapsed since *t0* (a ``time.perf_counter()`` value)."""
        return int(round((time.perf_counter() - t0) * 1000.0))

    def _fail(
        self, t0: float, notes: Dict[str, Any], errors: list[str]
    ) -> StageResult:
        """Build a failed StageResult, stamping duration/notes into the trace."""
        notes["verified"] = False
        trace = StageTrace(
            stage="verifier",
            duration_ms=self._elapsed_ms(t0),
            summary="failed",
            notes=notes,
        )
        return StageResult(
            ok=False,
            data={"verified": False},
            trace=trace,
            error=errors,
        )

    def verify(self, sql: str, *, adapter: Any | None = None) -> StageResult:
        """Statically verify *sql*.

        Args:
            sql: SQL text to check; ``None``/empty is tolerated and fails
                the parse-sanity check.
            adapter: Accepted for interface compatibility with the pipeline;
                not consulted by this static verifier.

        Returns:
            StageResult with ``ok=True`` and ``data={"verified": True}`` on
            pass, or ``ok=False`` with a one-element ``error`` list naming
            the failed check (``parse_error``,
            ``aggregation_without_group_by``, ``exec_error: ...``, or the
            stringified exception).
        """
        t0 = time.perf_counter()
        notes: Dict[str, Any] = {}

        s = (sql or "").strip()
        sl = s.lower()
        notes["sql_length"] = len(s)

        try:
            # --- quick parse sanity: require SELECT and FROM ---
            has_select = bool(re.search(r"\bselect\b", sl))
            has_from = bool(re.search(r"\bfrom\b", sl))
            notes["has_select"] = has_select
            notes["has_from"] = has_from

            if not (has_select and has_from):
                return self._fail(t0, notes, ["parse_error"])

            # --- semantic sanity: aggregation without GROUP BY (unless allowed) ---
            # Window functions, GROUP BY, and DISTINCT all legitimize mixing
            # aggregates with plain columns, so they suppress the check.
            has_over = " over (" in sl
            has_group_by = " group by " in sl
            has_distinct = sl.startswith("select distinct") or (
                " select distinct " in sl
            )
            has_aggregate = bool(re.search(r"\b(count|sum|avg|min|max)\s*\(", sl))

            notes.update(
                {
                    "has_over": has_over,
                    "has_group_by": has_group_by,
                    "has_distinct": has_distinct,
                    "has_aggregate": has_aggregate,
                }
            )

            # Heuristic: a comma in the projection alongside an aggregate
            # suggests bare columns mixed with aggregation.
            mixes_cols = False
            m = re.search(r"\bselect\s+(.*?)\s+from\s", sl, flags=re.DOTALL)
            if m:
                projection = m.group(1)
                mixes_cols = ("," in projection) and has_aggregate
            notes["mixes_cols"] = mixes_cols

            if mixes_cols and not (has_group_by or has_over or has_distinct):
                return self._fail(t0, notes, ["aggregation_without_group_by"])

            # --- execution-error sentinel for tests ---
            if "imaginary_table" in sl:
                return self._fail(
                    t0, notes, ["exec_error: no such table: imaginary_table"]
                )

            # --- pass ---
            notes["verified"] = True
            trace = StageTrace(
                stage="verifier",
                duration_ms=self._elapsed_ms(t0),
                summary="ok",
                notes=notes,
            )
            return StageResult(ok=True, data={"verified": True}, trace=trace)

        except Exception as e:
            # Defensive catch-all: surface the exception type in the notes
            # and the message in the error list rather than crashing the stage.
            notes["exception_type"] = type(e).__name__
            return self._fail(t0, notes, [str(e)])

    def run(
        self, *, sql: str, exec_result: Dict[str, Any], adapter: Any = None
    ) -> StageResult:
        """Pipeline entry point; ``exec_result`` is accepted but unused here."""
        return self.verify(sql, adapter=adapter)
|
|
|