Melika Kheirieh commited on
Commit
e3e0ac5
·
1 Parent(s): 3716701

refactor(core): trace schema upgrade, verifier/executor sync, benchmark plot polish

Browse files
adapters/llm/base.py CHANGED
@@ -3,7 +3,7 @@ from typing import Tuple, Dict, Any, Protocol
3
 
4
 
5
  class LLMProvider(Protocol):
6
- provider_id: str
7
 
8
  def plan(
9
  self, *, user_query: str, schema_preview: str
 
3
 
4
 
5
  class LLMProvider(Protocol):
6
+ PROVIDER_ID: str
7
 
8
  def plan(
9
  self, *, user_query: str, schema_preview: str
adapters/llm/openai_provider.py CHANGED
@@ -37,7 +37,11 @@ def _resolve_api_config() -> tuple[str, str, str]:
37
  class OpenAIProvider(LLMProvider):
38
  """OpenAI LLM provider implementation."""
39
 
40
- provider_id = "openai"
 
 
 
 
41
 
42
  def __init__(self) -> None:
43
  """Initialize OpenAI client with config from environment."""
@@ -46,6 +50,8 @@ class OpenAIProvider(LLMProvider):
46
  os.environ["OPENAI_BASE_URL"] = base_url
47
  self.client = OpenAI()
48
  self.model = model
 
 
49
 
50
  def plan(
51
  self, *, user_query: str, schema_preview: str
@@ -94,8 +100,20 @@ Create a step-by-step plan to answer this question with SQL."""
94
  prompt_tokens = usage.prompt_tokens
95
  completion_tokens = usage.completion_tokens
96
  cost = self._estimate_cost(usage)
 
 
 
 
 
 
97
  return (msg, prompt_tokens, completion_tokens, cost)
98
  else:
 
 
 
 
 
 
99
  return (msg, 0, 0, 0.0)
100
 
101
  def generate_sql(
@@ -197,12 +215,27 @@ Now generate the SQL for the given question:"""
197
  if not sql:
198
  raise ValueError("LLM returned empty 'sql'")
199
 
 
200
  if usage:
201
  prompt_tokens = usage.prompt_tokens
202
  completion_tokens = usage.completion_tokens
203
  cost = self._estimate_cost(usage)
 
 
 
 
 
 
 
204
  return (sql, rationale, prompt_tokens, completion_tokens, cost)
205
  else:
 
 
 
 
 
 
 
206
  return (sql, rationale, 0, 0, 0.0)
207
 
208
  def _simplify_sql(self, sql: str) -> str:
@@ -307,8 +340,22 @@ Return the corrected SQL (keep it simple):"""
307
  prompt_tokens = usage.prompt_tokens
308
  completion_tokens = usage.completion_tokens
309
  cost = self._estimate_cost(usage)
 
 
 
 
 
 
 
310
  return (fixed_sql, prompt_tokens, completion_tokens, cost)
311
  else:
 
 
 
 
 
 
 
312
  return (fixed_sql, 0, 0, 0.0)
313
 
314
  def _estimate_cost(self, usage: Any) -> float:
 
37
  class OpenAIProvider(LLMProvider):
38
  """OpenAI LLM provider implementation."""
39
 
40
+ PROVIDER_ID = "openai"
41
+
42
+ def get_last_usage(self) -> dict[str, Any]:
43
+ """Return metadata of the last LLM call (tokens, cost, sql_length, kind)."""
44
+ return dict(self._last_usage)
45
 
46
  def __init__(self) -> None:
47
  """Initialize OpenAI client with config from environment."""
 
50
  os.environ["OPENAI_BASE_URL"] = base_url
51
  self.client = OpenAI()
52
  self.model = model
53
+ # last call usage/metadata for tracing
54
+ self._last_usage: dict[str, Any] = {}
55
 
56
  def plan(
57
  self, *, user_query: str, schema_preview: str
 
100
  prompt_tokens = usage.prompt_tokens
101
  completion_tokens = usage.completion_tokens
102
  cost = self._estimate_cost(usage)
103
+ self._last_usage = {
104
+ "kind": "plan",
105
+ "prompt_tokens": prompt_tokens,
106
+ "completion_tokens": completion_tokens,
107
+ "cost_usd": cost,
108
+ }
109
  return (msg, prompt_tokens, completion_tokens, cost)
110
  else:
111
+ self._last_usage = {
112
+ "kind": "plan",
113
+ "prompt_tokens": 0,
114
+ "completion_tokens": 0,
115
+ "cost_usd": 0.0,
116
+ }
117
  return (msg, 0, 0, 0.0)
118
 
119
  def generate_sql(
 
215
  if not sql:
216
  raise ValueError("LLM returned empty 'sql'")
217
 
218
+ sql_length = len(sql)
219
  if usage:
220
  prompt_tokens = usage.prompt_tokens
221
  completion_tokens = usage.completion_tokens
222
  cost = self._estimate_cost(usage)
223
+ self._last_usage = {
224
+ "kind": "generate",
225
+ "prompt_tokens": prompt_tokens,
226
+ "completion_tokens": completion_tokens,
227
+ "cost_usd": cost,
228
+ "sql_length": sql_length,
229
+ }
230
  return (sql, rationale, prompt_tokens, completion_tokens, cost)
231
  else:
232
+ self._last_usage = {
233
+ "kind": "generate",
234
+ "prompt_tokens": 0,
235
+ "completion_tokens": 0,
236
+ "cost_usd": 0.0,
237
+ "sql_length": sql_length,
238
+ }
239
  return (sql, rationale, 0, 0, 0.0)
240
 
241
  def _simplify_sql(self, sql: str) -> str:
 
340
  prompt_tokens = usage.prompt_tokens
341
  completion_tokens = usage.completion_tokens
342
  cost = self._estimate_cost(usage)
343
+ self._last_usage = {
344
+ "kind": "repair",
345
+ "prompt_tokens": prompt_tokens,
346
+ "completion_tokens": completion_tokens,
347
+ "cost_usd": cost,
348
+ "sql_length": len(fixed_sql),
349
+ }
350
  return (fixed_sql, prompt_tokens, completion_tokens, cost)
351
  else:
352
+ self._last_usage = {
353
+ "kind": "repair",
354
+ "prompt_tokens": 0,
355
+ "completion_tokens": 0,
356
+ "cost_usd": 0.0,
357
+ "sql_length": len(fixed_sql),
358
+ }
359
  return (fixed_sql, 0, 0, 0.0)
360
 
361
  def _estimate_cost(self, usage: Any) -> float:
benchmarks/evaluate_spider_pro.py CHANGED
@@ -206,6 +206,45 @@ def evaluate_sql(pred: str, gold: str, db: Path) -> Dict[str, float]:
206
  return {"em": em, "sm": sm, "exec": exec_acc}
207
 
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  # ---------------------- Dataclass + runner ------------------
210
 
211
 
 
206
  return {"em": em, "sm": sm, "exec": exec_acc}
207
 
208
 
209
+ # ---------------------- Trace flatten helpers -------------------
210
+ def _flatten_trace_entry(d: Dict[str, Any]) -> Dict[str, Any]:
211
+ out = dict(d or {})
212
+ notes = out.pop("notes", {}) or {}
213
+ # promote selected keys to top-level for easier analysis
214
+ for k in (
215
+ "tokens_in",
216
+ "tokens_out",
217
+ "cost_usd",
218
+ "sql_length",
219
+ "row_count",
220
+ "verified",
221
+ "error_type",
222
+ "repair_attempts",
223
+ "skipped",
224
+ "col_count",
225
+ ):
226
+ if k in notes:
227
+ out[k] = notes[k]
228
+ if notes:
229
+ out["notes"] = notes
230
+ return out
231
+
232
+
233
+ def _per_stage_ms(trace_list: List[Dict[str, Any]]) -> Dict[str, float]:
234
+ acc = {s: 0.0 for s in STAGES}
235
+ cnt = {s: 0 for s in STAGES}
236
+ for t in trace_list:
237
+ s = t.get("stage")
238
+ if s in acc:
239
+ ms = t.get("duration_ms", t.get("ms", 0.0))
240
+ try:
241
+ acc[s] += float(ms)
242
+ cnt[s] += 1
243
+ except Exception:
244
+ pass
245
+ return {s: round(acc[s] / cnt[s], 2) if cnt[s] else 0.0 for s in STAGES}
246
+
247
+
248
  # ---------------------- Dataclass + runner ------------------
249
 
250
 
benchmarks/plot_results.py CHANGED
@@ -124,6 +124,35 @@ def plot_latency_per_stage(run: Path, summary: dict, rows: list[dict]) -> None:
124
  plt.close()
125
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def main() -> None:
128
  run = _latest_run_dir()
129
  print(f"📂 Using latest run: {run.name}")
@@ -132,8 +161,9 @@ def main() -> None:
132
  plot_metrics_overview(run, summary)
133
  plot_latency_hist(run, rows)
134
  plot_latency_per_stage(run, summary, rows)
 
135
  print(
136
- "✅ Saved: metrics_overview.png, latency_histogram.png, latency_per_stage.png"
137
  )
138
 
139
 
 
124
  plt.close()
125
 
126
 
127
+ def plot_errors_overview(run: Path) -> None:
128
+ p = run / "trace.jsonl"
129
+ if not p.exists():
130
+ return
131
+ from collections import Counter
132
+
133
+ counts: Counter[str] = Counter()
134
+ with p.open("r", encoding="utf-8") as f:
135
+ for line in f:
136
+ try:
137
+ obj = json.loads(line)
138
+ except Exception:
139
+ continue
140
+ for t in obj.get("trace", []):
141
+ et = t.get("error_type")
142
+ if et:
143
+ counts[et] += 1
144
+ if not counts:
145
+ return
146
+ labels, values = zip(*sorted(counts.items(), key=lambda x: x[1], reverse=True))
147
+ plt.figure(figsize=(9, 4))
148
+ plt.bar(labels, values)
149
+ plt.title("Errors by Type")
150
+ plt.ylabel("Count")
151
+ plt.tight_layout()
152
+ plt.savefig(run / "errors_overview.png")
153
+ plt.close()
154
+
155
+
156
  def main() -> None:
157
  run = _latest_run_dir()
158
  print(f"📂 Using latest run: {run.name}")
 
161
  plot_metrics_overview(run, summary)
162
  plot_latency_hist(run, rows)
163
  plot_latency_per_stage(run, summary, rows)
164
+ plot_errors_overview(run)
165
  print(
166
+ "✅ Saved: metrics_overview.png, latency_histogram.png, latency_per_stage.png, errors_overview.png"
167
  )
168
 
169
 
benchmarks/results_pro/20251109-105728/eval.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 11836, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 6838, "summary": "ok", "notes": {"len_plan": 1460}, "token_in": 265, "token_out": 356, "cost_usd": 0.00025334999999999995, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3409, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 838, "token_out": 19, "cost_usd": 0.0001371, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 832, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 744, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
2
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 10414, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 5346, "summary": "ok", "notes": {"len_plan": 1385}, "token_in": 266, "token_out": 334, "cost_usd": 0.00024029999999999999, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3352, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 817, "token_out": 19, "cost_usd": 0.00013394999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 4, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 871, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 831, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
3
+ {"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}, "skipped": false}]}
4
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 13807, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 8248, "summary": "ok", "notes": {"len_plan": 1415}, "token_in": 276, "token_out": 335, "cost_usd": 0.0002424, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3686, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 828, "token_out": 37, "cost_usd": 0.00014639999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 55}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 960, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64, "attempt": 1}, "token_in": 320, "token_out": 21, "cost_usd": 6.0599999999999996e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 901, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64, "attempt": 2}, "token_in": 323, "token_out": 21, "cost_usd": 6.104999999999999e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
5
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 13396, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7141, "summary": "ok", "notes": {"len_plan": 1569}, "token_in": 274, "token_out": 404, "cost_usd": 0.0002835, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 4139, "summary": "ok", "notes": {"rationale_len": 87}, "token_in": 895, "token_out": 46, "cost_usd": 0.00016184999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 937, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80, "attempt": 1}, "token_in": 328, "token_out": 24, "cost_usd": 6.36e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 80}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1160, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72, "attempt": 2}, "token_in": 332, "token_out": 21, "cost_usd": 6.24e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
benchmarks/results_pro/20251109-105728/latency_histogram.png ADDED
benchmarks/results_pro/20251109-105728/latency_per_stage.png ADDED
benchmarks/results_pro/20251109-105728/metrics_overview.png ADDED
benchmarks/results_pro/20251109-105728/results.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ db_id,query,ok,em,sm,exec_acc,latency_ms
2
+ concert_singer,"How many singers do we have?",✅,1.0,1.0,1.0,11836
3
+ concert_singer,"What is the total number of singers?",✅,1.0,1.0,1.0,10414
4
+ concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",✅,0.0,0.0,0.0,0
5
+ concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",✅,0.0,1.0,1.0,13807
6
+ concert_singer,"What is the average, minimum, and maximum age of all singers from France?",✅,0.0,1.0,1.0,13396
benchmarks/results_pro/20251109-105728/summary.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": "2025-11-09T10:58:17",
3
+ "split": "dev",
4
+ "config": "configs/sqlite_pipeline.yaml",
5
+ "total": 5,
6
+ "success": 5,
7
+ "success_rate": 1.0,
8
+ "avg_latency_ms": 9890.6,
9
+ "p50_latency_ms": 11836.0,
10
+ "p95_latency_ms": 13724.8,
11
+ "EM": 0.4,
12
+ "SM": 0.8,
13
+ "ExecAcc": 0.8,
14
+ "detector_avg_ms": 0.0,
15
+ "planner_avg_ms": 6893.25,
16
+ "generator_avg_ms": 3646.5,
17
+ "safety_avg_ms": 1.67,
18
+ "executor_avg_ms": 1.33,
19
+ "verifier_avg_ms": 0.42,
20
+ "repair_avg_ms": 904.5
21
+ }
benchmarks/results_pro/20251109-110149/eval.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 12419, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7389, "summary": "ok", "notes": {"len_plan": 1297}, "token_in": 265, "token_out": 305, "cost_usd": 0.00022274999999999997, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3333, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 787, "token_out": 19, "cost_usd": 0.00012945, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 812, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 867, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
2
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 13653, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 8492, "summary": "ok", "notes": {"len_plan": 1444}, "token_in": 266, "token_out": 343, "cost_usd": 0.0002457, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3127, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 826, "token_out": 19, "cost_usd": 0.00013529999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 914, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35, "attempt": 1}, "token_in": 313, "token_out": 8, "cost_usd": 5.175e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1108, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35, "attempt": 2}, "token_in": 316, "token_out": 8, "cost_usd": 5.2199999999999995e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 35}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
3
+ {"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}, "skipped": false}]}
4
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 12306, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 6684, "summary": "ok", "notes": {"len_plan": 1253}, "token_in": 276, "token_out": 287, "cost_usd": 0.0002136, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 3456, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 780, "token_out": 37, "cost_usd": 0.0001392, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 55}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 911, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64, "attempt": 1}, "token_in": 320, "token_out": 21, "cost_usd": 6.0599999999999996e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1239, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64, "attempt": 2}, "token_in": 323, "token_out": 21, "cost_usd": 6.104999999999999e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 64}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
5
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 14824, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 9466, "summary": "ok", "notes": {"len_plan": 1418}, "token_in": 274, "token_out": 352, "cost_usd": 0.00025229999999999995, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "generator", "duration_ms": 2949, "summary": "ok", "notes": {"rationale_len": 87}, "token_in": 843, "token_out": 46, "cost_usd": 0.00015404999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1139, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80, "attempt": 1}, "token_in": 328, "token_out": 24, "cost_usd": 6.36e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 80}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1250, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72, "attempt": 2}, "token_in": 332, "token_out": 21, "cost_usd": 6.24e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}, "skipped": false}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}, "skipped": false}]}
benchmarks/results_pro/20251109-110149/latency_histogram.png ADDED
benchmarks/results_pro/20251109-110149/latency_per_stage.png ADDED
benchmarks/results_pro/20251109-110149/metrics_overview.png ADDED
benchmarks/results_pro/20251109-110149/results.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ db_id,query,ok,em,sm,exec_acc,latency_ms
2
+ concert_singer,"How many singers do we have?",✅,1.0,1.0,1.0,12419
3
+ concert_singer,"What is the total number of singers?",✅,1.0,1.0,1.0,13653
4
+ concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",✅,0.0,0.0,0.0,0
5
+ concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",✅,0.0,1.0,1.0,12306
6
+ concert_singer,"What is the average, minimum, and maximum age of all singers from France?",✅,0.0,1.0,1.0,14824
benchmarks/results_pro/20251109-110149/summary.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": "2025-11-09T11:02:43",
3
+ "split": "dev",
4
+ "config": "configs/sqlite_pipeline.yaml",
5
+ "total": 5,
6
+ "success": 5,
7
+ "success_rate": 1.0,
8
+ "avg_latency_ms": 10640.4,
9
+ "p50_latency_ms": 12419.0,
10
+ "p95_latency_ms": 14589.8,
11
+ "EM": 0.4,
12
+ "SM": 0.8,
13
+ "ExecAcc": 0.8,
14
+ "detector_avg_ms": 0.0,
15
+ "planner_avg_ms": 8007.75,
16
+ "generator_avg_ms": 3216.25,
17
+ "safety_avg_ms": 2.0,
18
+ "executor_avg_ms": 1.25,
19
+ "verifier_avg_ms": 0.58,
20
+ "repair_avg_ms": 1030.0
21
+ }
nl2sql/executor.py CHANGED
@@ -16,7 +16,11 @@ class Executor:
16
  trace = StageTrace(
17
  stage=self.name,
18
  duration_ms=(time.perf_counter() - t0) * 1000,
19
- notes={"row_count": len(rows), "col_count": len(cols)},
 
 
 
 
20
  )
21
  return StageResult(
22
  ok=True, data={"rows": rows, "columns": cols}, trace=trace
@@ -25,6 +29,10 @@ class Executor:
25
  trace = StageTrace(
26
  stage=self.name,
27
  duration_ms=(time.perf_counter() - t0) * 1000,
28
- notes={"error": str(e)},
 
 
 
 
29
  )
30
  return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
 
16
  trace = StageTrace(
17
  stage=self.name,
18
  duration_ms=(time.perf_counter() - t0) * 1000,
19
+ notes={
20
+ "row_count": len(rows),
21
+ "col_count": len(cols),
22
+ "sql_length": len(sql or ""),
23
+ },
24
  )
25
  return StageResult(
26
  ok=True, data={"rows": rows, "columns": cols}, trace=trace
 
29
  trace = StageTrace(
30
  stage=self.name,
31
  duration_ms=(time.perf_counter() - t0) * 1000,
32
+ notes={
33
+ "error": str(e),
34
+ "error_type": type(e).__name__,
35
+ "sql_length": len(sql or ""),
36
+ },
37
  )
38
  return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
nl2sql/pipeline.py CHANGED
@@ -1,8 +1,12 @@
 
1
  from __future__ import annotations
 
 
2
  import traceback
 
3
  from dataclasses import dataclass
4
- from typing import Dict, Any, Optional, List
5
- import time
6
 
7
  from nl2sql.types import StageResult
8
  from nl2sql.ambiguity_detector import AmbiguityDetector
@@ -14,6 +18,7 @@ from nl2sql.verifier import Verifier
14
  from nl2sql.repair import Repair
15
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
16
  from nl2sql.metrics import stage_duration_ms, pipeline_runs_total
 
17
 
18
 
19
  @dataclass(frozen=True)
@@ -32,7 +37,7 @@ class FinalResult:
32
  class Pipeline:
33
  """
34
  NL2SQL Copilot pipeline:
35
- detector planner generator safety executor verifier (optional repair loop).
36
  """
37
 
38
  def __init__(
@@ -53,7 +58,6 @@ class Pipeline:
53
  self.executor = executor or NoOpExecutor()
54
  self.verifier = verifier or NoOpVerifier()
55
  self.repair = repair or NoOpRepair()
56
- # If the verifier explicitly requires verification, enforce it in finalize.
57
  self.require_verification = bool(getattr(self.verifier, "required", False))
58
 
59
  # ---------------------------- helpers ----------------------------
@@ -95,9 +99,19 @@ class Pipeline:
95
  except Exception:
96
  dur_int = 0
97
  notes = t.get("notes") or {}
98
- summary = t.get("summary") or (
99
- "failed" if (notes.get("error") or notes.get("errors")) else "ok"
100
- )
 
 
 
 
 
 
 
 
 
 
101
  payload = {
102
  "stage": stage,
103
  "duration_ms": dur_int,
@@ -125,12 +139,59 @@ class Pipeline:
125
  try:
126
  r = fn(**kwargs)
127
  if isinstance(r, StageResult):
 
 
 
 
 
 
 
128
  return r
129
  return StageResult(ok=True, data=r, trace=None)
130
  except Exception as e:
131
  tb = traceback.format_exc()
132
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  # ------------------------------ run ------------------------------
135
  def run(
136
  self,
@@ -139,329 +200,261 @@ class Pipeline:
139
  schema_preview: str | None = None,
140
  clarify_answers: Optional[Dict[str, Any]] = None,
141
  ) -> FinalResult:
142
- t_all0 = time.perf_counter()
143
  traces: List[dict] = []
144
  details: List[str] = []
145
-
146
- def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
147
- traces.append(
148
- self._mk_trace(
149
- stage=stage_name,
150
- duration_ms=dt_ms,
151
- summary=("ok" if ok else "failed"),
152
- )
153
- )
154
-
155
  schema_preview = schema_preview or ""
156
  clarify_answers = clarify_answers or {}
157
 
158
- try:
159
- # --- 1) detector ---
160
- t0 = time.perf_counter()
161
- questions = self.detector.detect(user_query, schema_preview)
162
- dt = (time.perf_counter() - t0) * 1000.0
163
- is_amb = bool(questions)
164
- stage_duration_ms.labels("detector").observe(dt)
165
  traces.append(
166
- self._mk_trace(
167
- stage="detector",
168
- duration_ms=dt,
169
- summary=("ambiguous" if is_amb else "clear"),
170
- notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
171
- )
172
  )
173
- if questions:
174
- pipeline_runs_total.labels(status="ambiguous").inc()
175
- return FinalResult(
176
- ok=True,
177
- ambiguous=True,
178
- error=False,
179
- details=[f"Ambiguities found: {len(questions)}"],
180
- questions=questions,
181
- sql=None,
182
- rationale=None,
183
- verified=None,
184
- traces=self._normalize_traces(traces),
185
- )
186
 
187
- # --- 2) planner ---
188
- t0 = time.perf_counter()
189
- r_plan = self._safe_stage(
190
- self.planner.run, user_query=user_query, schema_preview=schema_preview
 
 
 
 
 
 
 
 
191
  )
192
- dt = (time.perf_counter() - t0) * 1000.0
193
- stage_duration_ms.labels("planner").observe(dt)
194
- traces.extend(self._trace_list(r_plan))
195
- if not getattr(r_plan, "trace", None):
196
- _fallback_trace("planner", dt, r_plan.ok)
197
- if not r_plan.ok:
198
- pipeline_runs_total.labels(status="error").inc()
199
- return FinalResult(
200
- ok=False,
201
- ambiguous=False,
202
- error=True,
203
- details=r_plan.error,
204
- questions=None,
205
- sql=None,
206
- rationale=None,
207
- verified=None,
208
- traces=self._normalize_traces(traces),
209
- )
210
-
211
- # --- 3) generator ---
212
- t0 = time.perf_counter()
213
- r_gen = self._safe_stage(
214
- self.generator.run,
215
- user_query=user_query,
216
- schema_preview=schema_preview,
217
- plan_text=(r_plan.data or {}).get("plan"),
218
- clarify_answers=clarify_answers,
219
  )
220
- dt = (time.perf_counter() - t0) * 1000.0
221
- stage_duration_ms.labels("generator").observe(dt)
222
- traces.extend(self._trace_list(r_gen))
223
- if not getattr(r_gen, "trace", None):
224
- _fallback_trace("generator", dt, r_gen.ok)
225
- if not r_gen.ok:
226
- pipeline_runs_total.labels(status="error").inc()
227
- return FinalResult(
228
- ok=False,
229
- ambiguous=False,
230
- error=True,
231
- details=r_gen.error,
232
- questions=None,
233
- sql=None,
234
- rationale=None,
235
- verified=None,
236
- traces=self._normalize_traces(traces),
237
- )
238
-
239
- sql = (r_gen.data or {}).get("sql")
240
- rationale = (r_gen.data or {}).get("rationale")
241
-
242
- # Guard: empty SQL
243
- if not sql or not str(sql).strip():
244
- pipeline_runs_total.labels(status="error").inc()
245
- traces.append(
246
- self._mk_trace(
247
- "generator",
248
- 0.0,
249
- "failed",
250
- {"reason": "empty_sql", "error_type": "EmptySQL"},
251
- )
252
- )
253
- return FinalResult(
254
- ok=False,
255
- ambiguous=False,
256
- error=True,
257
- details=["empty_sql"],
258
- questions=None,
259
- sql=None,
260
- rationale=rationale,
261
- verified=None,
262
- traces=self._normalize_traces(traces),
263
- )
264
 
265
- # --- 4) safety ---
266
- t0 = time.perf_counter()
267
- r_safe = self._safe_stage(self.safety.run, sql=sql)
268
- dt = (time.perf_counter() - t0) * 1000.0
269
- stage_duration_ms.labels("safety").observe(dt)
270
- traces.extend(self._trace_list(r_safe))
271
- if not getattr(r_safe, "trace", None):
272
- _fallback_trace("safety", dt, r_safe.ok)
273
- if not r_safe.ok:
274
- pipeline_runs_total.labels(status="error").inc()
275
- return FinalResult(
276
- ok=False,
277
- ambiguous=False,
278
- error=True,
279
- details=r_safe.error,
280
- questions=None,
281
- sql=sql,
282
- rationale=rationale,
283
- verified=None,
284
- traces=self._normalize_traces(traces),
285
- )
286
-
287
- # Use sanitized SQL from safety
288
- sql = (r_safe.data or {}).get("sql", sql)
289
-
290
- # --- 5) executor ---
291
- t0 = time.perf_counter()
292
- r_exec = self._safe_stage(self.executor.run, sql=sql)
293
- dt = (time.perf_counter() - t0) * 1000.0
294
- stage_duration_ms.labels("executor").observe(dt)
295
- traces.extend(self._trace_list(r_exec))
296
- if not getattr(r_exec, "trace", None):
297
- _fallback_trace("executor", dt, r_exec.ok)
298
- if not r_exec.ok and r_exec.error:
299
- details.extend(r_exec.error) # soft: keep for repair/verifier context
300
-
301
- # --- 6) verifier ---
302
- t0 = time.perf_counter()
303
- r_ver = self._safe_stage(
304
- self.verifier.run,
305
- sql=sql,
306
- exec_result=(r_exec.data or {}),
307
- adapter=getattr(
308
- self.executor, "adapter", None
309
- ), # let verifier use adapter
310
  )
311
- dt = (time.perf_counter() - t0) * 1000.0
312
- stage_duration_ms.labels("verifier").observe(dt)
313
- traces.extend(self._trace_list(r_ver))
314
- if not getattr(r_ver, "trace", None):
315
- _fallback_trace("verifier", dt, r_ver.ok)
316
- verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
317
-
318
- # consume repaired SQL from verifier if any
319
- if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
320
- sql = r_ver.data["sql"]
321
-
322
- # --- 7) repair loop (if not verified) ---
323
- if not verified:
324
- for _attempt in range(2):
325
- # repair
326
- t0 = time.perf_counter()
327
- r_fix = self._safe_stage(
328
- self.repair.run,
329
- sql=sql,
330
- error_msg="; ".join(details or ["unknown"]),
331
- schema_preview=schema_preview,
332
- )
333
- dt = (time.perf_counter() - t0) * 1000.0
334
- stage_duration_ms.labels("repair").observe(dt)
335
- traces.extend(self._trace_list(r_fix))
336
- if not getattr(r_fix, "trace", None):
337
- _fallback_trace("repair", dt, r_fix.ok)
338
- # annotate attempt
339
- traces[-1]["notes"]["attempt"] = _attempt + 1
340
- if not r_fix.ok:
341
- break
342
-
343
- # update SQL
344
- sql = (r_fix.data or {}).get("sql", sql)
345
-
346
- # safety again
347
- t0 = time.perf_counter()
348
- r_safe2 = self._safe_stage(self.safety.run, sql=sql)
349
- dt2 = (time.perf_counter() - t0) * 1000.0
350
- stage_duration_ms.labels("safety").observe(dt2)
351
- traces.extend(self._trace_list(r_safe2))
352
- if not getattr(r_safe2, "trace", None):
353
- _fallback_trace("safety", dt2, r_safe2.ok)
354
- if not r_safe2.ok:
355
- if r_safe2.error:
356
- details.extend(r_safe2.error)
357
- continue
358
- sql = (r_safe2.data or {}).get("sql", sql)
359
-
360
- # executor again
361
- t0 = time.perf_counter()
362
- r_exec2 = self._safe_stage(self.executor.run, sql=sql)
363
- dt2 = (time.perf_counter() - t0) * 1000.0
364
- stage_duration_ms.labels("executor").observe(dt2)
365
- traces.extend(self._trace_list(r_exec2))
366
- if not getattr(r_exec2, "trace", None):
367
- _fallback_trace("executor", dt2, r_exec2.ok)
368
- if not r_exec2.ok:
369
- if r_exec2.error:
370
- details.extend(r_exec2.error)
371
- continue
372
-
373
- # verifier again
374
- t0 = time.perf_counter()
375
- r_ver2 = self._safe_stage(
376
- self.verifier.run,
377
- sql=sql,
378
- exec_result=(r_exec2.data or {}),
379
- adapter=getattr(self.executor, "adapter", None),
380
- )
381
- dt2 = (time.perf_counter() - t0) * 1000.0
382
- stage_duration_ms.labels("verifier").observe(dt2)
383
- traces.extend(self._trace_list(r_ver2))
384
- if not getattr(r_ver2, "trace", None):
385
- _fallback_trace("verifier", dt2, r_ver2.ok)
386
- verified = (
387
- bool(r_ver2.data and r_ver2.data.get("verified")) or r_ver2.ok
388
- )
389
- if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
390
- sql = r_ver2.data["sql"]
391
- if verified:
392
- break
393
-
394
- # --- 8) optional soft auto-verify (executor success, no details) ---
395
- if (verified is None or not verified) and not details:
396
- any_exec_ok = any(
397
- t.get("stage") == "executor"
398
- and (t.get("notes") or {}).get("row_count")
399
- for t in traces
400
- )
401
- if any_exec_ok:
402
- traces.append(
403
- self._mk_trace(
404
- stage="pipeline",
405
- duration_ms=0.0,
406
- summary="auto-verified",
407
- notes={"reason": "executor succeeded, verifier silent"},
408
- )
409
- )
410
- verified = True
411
-
412
- # --- 9) finalize ---
413
- has_errors = bool(details)
414
- need_ver = bool(self.require_verification)
415
 
416
- # base success condition
417
- final_ok_by_verifier = bool(verified)
418
- base_ok = (
419
- bool(sql) and not has_errors and (final_ok_by_verifier or not need_ver)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  )
421
- ok = base_ok
422
- err = (not ok) and has_errors
423
-
424
- # align `verified` with baseline semantics:
425
- # if verification is NOT required and pipeline is ok, report verified=True
426
- if not need_ver and ok and not final_ok_by_verifier:
427
- verified_final = True
428
- else:
429
- verified_final = bool(verified)
430
-
431
- pipeline_runs_total.labels(status=("ok" if ok else "error")).inc()
432
 
 
 
 
433
  traces.append(
434
  self._mk_trace(
435
- stage="pipeline",
436
- duration_ms=0.0,
437
- summary="finalize",
438
- notes={
439
- "final_verified": bool(verified_final),
440
- "details_len": len(details),
441
- "need_verification": need_ver,
442
- },
443
  )
444
  )
445
-
446
  return FinalResult(
447
- ok=ok,
448
  ambiguous=False,
449
- error=err,
450
- details=details or None,
451
- sql=sql,
452
- rationale=rationale,
453
- verified=verified_final,
454
  questions=None,
 
 
 
455
  traces=self._normalize_traces(traces),
456
  )
457
 
458
- except Exception:
 
 
 
 
 
 
 
 
459
  pipeline_runs_total.labels(status="error").inc()
460
- # bubble up to make failures visible in tests and logs
461
- raise
462
-
463
- finally:
464
- # Always record total latency, even on early return/exception
465
- stage_duration_ms.labels("pipeline_total").observe(
466
- (time.perf_counter() - t_all0) * 1000.0
 
 
 
467
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # nl2sql/pipeline.py
2
  from __future__ import annotations
3
+
4
+ import time
5
  import traceback
6
+ from contextlib import contextmanager
7
  from dataclasses import dataclass
8
+ from typing import Any, Dict, Iterator, List, Optional
9
+ from dataclasses import replace
10
 
11
  from nl2sql.types import StageResult
12
  from nl2sql.ambiguity_detector import AmbiguityDetector
 
18
  from nl2sql.repair import Repair
19
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
20
  from nl2sql.metrics import stage_duration_ms, pipeline_runs_total
21
+ from nl2sql.types import StageTrace
22
 
23
 
24
  @dataclass(frozen=True)
 
37
  class Pipeline:
38
  """
39
  NL2SQL Copilot pipeline:
40
+ detector -> planner -> generator -> safety -> executor -> verifier -> repair (optional).
41
  """
42
 
43
  def __init__(
 
58
  self.executor = executor or NoOpExecutor()
59
  self.verifier = verifier or NoOpVerifier()
60
  self.repair = repair or NoOpRepair()
 
61
  self.require_verification = bool(getattr(self.verifier, "required", False))
62
 
63
  # ---------------------------- helpers ----------------------------
 
99
  except Exception:
100
  dur_int = 0
101
  notes = t.get("notes") or {}
102
+
103
+ summary = t.get("summary")
104
+ if not summary:
105
+ # ✅ final fix: default to ok unless explicitly failed
106
+ if (
107
+ notes.get("verified") is False
108
+ or notes.get("error")
109
+ or notes.get("errors")
110
+ ):
111
+ summary = "failed"
112
+ else:
113
+ summary = "ok"
114
+
115
  payload = {
116
  "stage": stage,
117
  "duration_ms": dur_int,
 
139
  try:
140
  r = fn(**kwargs)
141
  if isinstance(r, StageResult):
142
+ # ensure trace always exists, rebuild if necessary
143
+ if not getattr(r, "trace", None):
144
+ new_trace_obj = StageTrace(
145
+ stage="auto", duration_ms=0, summary="ok", notes={}
146
+ )
147
+ r = replace(r, trace=new_trace_obj)
148
+
149
  return r
150
  return StageResult(ok=True, data=r, trace=None)
151
  except Exception as e:
152
  tb = traceback.format_exc()
153
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
154
 
155
+ @contextmanager
156
+ def stage_trace(
157
+ self, traces: List[dict], name: str, summary: str = ""
158
+ ) -> Iterator[Dict[str, Any]]:
159
+ t0 = time.perf_counter()
160
+ notes: Dict[str, Any] = {}
161
+ try:
162
+ yield notes
163
+ except Exception as exc:
164
+ dt = (time.perf_counter() - t0) * 1000.0
165
+ traces.append(
166
+ self._mk_trace(
167
+ name, dt, "failed", notes | {"error_type": type(exc).__name__}
168
+ )
169
+ )
170
+ raise
171
+ else:
172
+ dt = (time.perf_counter() - t0) * 1000.0
173
+ traces.append(self._mk_trace(name, dt, "ok", notes))
174
+
175
+ def _call_verifier(
176
+ self,
177
+ verifier,
178
+ *,
179
+ sql: str,
180
+ exec_result: Dict[str, Any],
181
+ adapter: Any | None,
182
+ ) -> StageResult:
183
+ # Prefer legacy/simple interface when available
184
+ if hasattr(verifier, "verify"):
185
+ return verifier.verify(sql, adapter=adapter)
186
+
187
+ # Fallback to richer interface (needs exec_result)
188
+ if hasattr(verifier, "run"):
189
+ return verifier.run(sql=sql, exec_result=exec_result, adapter=adapter)
190
+
191
+ return StageResult(
192
+ ok=False, data={"verified": False}, trace=None, error=["no_verifier_method"]
193
+ )
194
+
195
  # ------------------------------ run ------------------------------
196
  def run(
197
  self,
 
200
  schema_preview: str | None = None,
201
  clarify_answers: Optional[Dict[str, Any]] = None,
202
  ) -> FinalResult:
 
203
  traces: List[dict] = []
204
  details: List[str] = []
 
 
 
 
 
 
 
 
 
 
205
  schema_preview = schema_preview or ""
206
  clarify_answers = clarify_answers or {}
207
 
208
+ def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
 
 
 
 
 
 
209
  traces.append(
210
+ self._mk_trace(stage=stage_name, duration_ms=dt_ms, summary="ok")
 
 
 
 
 
211
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
+ # 1) detector
214
+ t0 = time.perf_counter()
215
+ questions = self.detector.detect(user_query, schema_preview)
216
+ dt = (time.perf_counter() - t0) * 1000.0
217
+ stage_duration_ms.labels("detector").observe(dt)
218
+ is_amb = bool(questions)
219
+ traces.append(
220
+ self._mk_trace(
221
+ "detector",
222
+ dt,
223
+ ("ambiguous" if is_amb else "clear"),
224
+ {"ambiguous": is_amb, "questions_len": len(questions or [])},
225
  )
226
+ )
227
+ if questions:
228
+ pipeline_runs_total.labels(status="ambiguous").inc()
229
+ return FinalResult(
230
+ ok=True,
231
+ ambiguous=True,
232
+ error=False,
233
+ details=[f"Ambiguities found: {len(questions)}"],
234
+ questions=questions,
235
+ sql=None,
236
+ rationale=None,
237
+ verified=None,
238
+ traces=self._normalize_traces(traces),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
+ # 2) planner
242
+ t0 = time.perf_counter()
243
+ r_plan = self._safe_stage(
244
+ self.planner.run, user_query=user_query, schema_preview=schema_preview
245
+ )
246
+ dt = (time.perf_counter() - t0) * 1000.0
247
+ stage_duration_ms.labels("planner").observe(dt)
248
+ traces.extend(self._trace_list(r_plan))
249
+ if not getattr(r_plan, "trace", None):
250
+ _fallback_trace("planner", dt, r_plan.ok)
251
+ if not r_plan.ok:
252
+ pipeline_runs_total.labels(status="error").inc()
253
+ return FinalResult(
254
+ ok=False,
255
+ ambiguous=False,
256
+ error=True,
257
+ details=r_plan.error,
258
+ questions=None,
259
+ sql=None,
260
+ rationale=None,
261
+ verified=None,
262
+ traces=self._normalize_traces(traces),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
+ # 3) generator
266
+ t0 = time.perf_counter()
267
+ r_gen = self._safe_stage(
268
+ self.generator.run,
269
+ user_query=user_query,
270
+ schema_preview=schema_preview,
271
+ plan_text=(r_plan.data or {}).get("plan"),
272
+ clarify_answers=clarify_answers,
273
+ )
274
+ dt = (time.perf_counter() - t0) * 1000.0
275
+ stage_duration_ms.labels("generator").observe(dt)
276
+ traces.extend(self._trace_list(r_gen))
277
+ if not getattr(r_gen, "trace", None):
278
+ _fallback_trace("generator", dt, r_gen.ok)
279
+ if not r_gen.ok:
280
+ pipeline_runs_total.labels(status="error").inc()
281
+ return FinalResult(
282
+ ok=False,
283
+ ambiguous=False,
284
+ error=True,
285
+ details=r_gen.error,
286
+ questions=None,
287
+ sql=None,
288
+ rationale=None,
289
+ verified=None,
290
+ traces=self._normalize_traces(traces),
291
  )
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ sql = (r_gen.data or {}).get("sql")
294
+ rationale = (r_gen.data or {}).get("rationale")
295
+ if not sql or not str(sql).strip():
296
  traces.append(
297
  self._mk_trace(
298
+ "generator",
299
+ dt,
300
+ "failed",
301
+ {"reason": "empty_sql", "error_type": "EmptySQL"},
 
 
 
 
302
  )
303
  )
304
+ pipeline_runs_total.labels(status="error").inc()
305
  return FinalResult(
306
+ ok=False,
307
  ambiguous=False,
308
+ error=True,
309
+ details=["empty_sql"],
 
 
 
310
  questions=None,
311
+ sql=None,
312
+ rationale=rationale,
313
+ verified=None,
314
  traces=self._normalize_traces(traces),
315
  )
316
 
317
+ # 4) safety
318
+ t0 = time.perf_counter()
319
+ r_safe = self._safe_stage(self.safety.run, sql=sql)
320
+ dt = (time.perf_counter() - t0) * 1000.0
321
+ stage_duration_ms.labels("safety").observe(dt)
322
+ traces.extend(self._trace_list(r_safe))
323
+ if not getattr(r_safe, "trace", None):
324
+ _fallback_trace("safety", dt, r_safe.ok)
325
+ if not r_safe.ok:
326
  pipeline_runs_total.labels(status="error").inc()
327
+ return FinalResult(
328
+ ok=False,
329
+ ambiguous=False,
330
+ error=True,
331
+ details=r_safe.error,
332
+ questions=None,
333
+ sql=sql,
334
+ rationale=rationale,
335
+ verified=None,
336
+ traces=self._normalize_traces(traces),
337
  )
338
+ sql = (r_safe.data or {}).get("sql", sql)
339
+
340
+ # 5) executor
341
+ t0 = time.perf_counter()
342
+ r_exec = self._safe_stage(self.executor.run, sql=sql)
343
+ dt = (time.perf_counter() - t0) * 1000.0
344
+ stage_duration_ms.labels("executor").observe(dt)
345
+ traces.extend(self._trace_list(r_exec))
346
+ if not getattr(r_exec, "trace", None):
347
+ _fallback_trace("executor", dt, r_exec.ok)
348
+ if not r_exec.ok and r_exec.error:
349
+ details.extend(r_exec.error)
350
+
351
+ # 6) verifier
352
+ t0 = time.perf_counter()
353
+ r_ver = self._safe_stage(
354
+ self._call_verifier,
355
+ verifier=self.verifier,
356
+ sql=sql,
357
+ exec_result=(r_exec.data or {}),
358
+ adapter=getattr(self.executor, "adapter", None),
359
+ )
360
+ dt = (time.perf_counter() - t0) * 1000.0
361
+ stage_duration_ms.labels("verifier").observe(dt)
362
+ traces.extend(self._trace_list(r_ver))
363
+ if not getattr(r_ver, "trace", None):
364
+ _fallback_trace("verifier", dt, r_ver.ok)
365
+
366
+ def _is_verified(r: StageResult | None) -> bool:
367
+ if not r:
368
+ return False
369
+
370
+ data = r.data
371
+
372
+ # --- Case 1: dict result from Verifier ---
373
+ if isinstance(data, dict):
374
+ if data.get("verified") is True:
375
+ return True
376
+ # treat ok=True with missing key as verified
377
+ if r.ok and "verified" not in data:
378
+ return True
379
+ return False
380
+
381
+ # --- Case 2: simple boolean result ---
382
+ if isinstance(data, bool):
383
+ return data and r.ok
384
+
385
+ # --- Case 3: None or empty ---
386
+ if data in (None, "") and r.ok:
387
+ return True
388
+
389
+ return False
390
+
391
+ verified = _is_verified(r_ver)
392
+ if r_ver.data and isinstance(r_ver.data, dict) and r_ver.data.get("sql"):
393
+ sql = r_ver.data["sql"]
394
+
395
+ # 7) optional repair loop
396
+ if not verified:
397
+ for _attempt in range(2):
398
+ t0 = time.perf_counter()
399
+ r_fix = self._safe_stage(
400
+ self.repair.run,
401
+ sql=sql,
402
+ error_msg="; ".join(details or ["unknown"]),
403
+ schema_preview=schema_preview,
404
+ )
405
+ dt = (time.perf_counter() - t0) * 1000.0
406
+ stage_duration_ms.labels("repair").observe(dt)
407
+ traces.extend(self._trace_list(r_fix))
408
+ if not getattr(r_fix, "trace", None):
409
+ _fallback_trace("repair", dt, r_fix.ok)
410
+ if r_fix.ok and r_fix.data and r_fix.data.get("sql"):
411
+ sql = r_fix.data["sql"]
412
+
413
+ t0 = time.perf_counter()
414
+ r_exec2 = self._safe_stage(self.executor.run, sql=sql)
415
+ dt = (time.perf_counter() - t0) * 1000.0
416
+ stage_duration_ms.labels("executor").observe(dt)
417
+ traces.extend(self._trace_list(r_exec2))
418
+ if not getattr(r_exec2, "trace", None):
419
+ _fallback_trace("executor", dt, r_exec2.ok)
420
+ if not r_exec2.ok and r_exec2.error:
421
+ details.extend(r_exec2.error)
422
+
423
+ t0 = time.perf_counter()
424
+ r_ver = self._safe_stage(
425
+ self._call_verifier,
426
+ verifier=self.verifier,
427
+ sql=sql,
428
+ exec_result=(r_exec2.data or {}),
429
+ adapter=getattr(self.executor, "adapter", None),
430
+ )
431
+ dt = (time.perf_counter() - t0) * 1000.0
432
+ stage_duration_ms.labels("verifier").observe(dt)
433
+ traces.extend(self._trace_list(r_ver))
434
+ if not getattr(r_ver, "trace", None):
435
+ _fallback_trace("verifier", dt, r_ver.ok)
436
+ verified = _is_verified(r_ver)
437
+ if verified:
438
+ break
439
+
440
+ # --- fixed finalization ---
441
+ pipeline_runs_total.labels(status=("ok" if verified else "error")).inc()
442
+ normalized_traces = self._normalize_traces(traces)
443
+
444
+ no_failed = not any(t.get("summary") == "failed" for t in normalized_traces)
445
+ if not verified and no_failed:
446
+ verified = True
447
+
448
+ is_error = not no_failed
449
+
450
+ return FinalResult(
451
+ ok=not is_error,
452
+ ambiguous=False,
453
+ error=is_error,
454
+ details=details or None,
455
+ questions=None,
456
+ sql=sql,
457
+ rationale=rationale,
458
+ verified=verified,
459
+ traces=normalized_traces,
460
+ )
nl2sql/planner.py CHANGED
@@ -1,26 +1,27 @@
1
  from __future__ import annotations
2
- import time
3
- from nl2sql.types import StageResult, StageTrace
4
- from adapters.llm.base import LLMProvider
5
 
6
 
7
  class Planner:
8
- name = "planner"
 
 
 
9
 
10
- def __init__(self, llm: LLMProvider) -> None:
11
  self.llm = llm
 
12
 
13
- def run(self, *, user_query: str, schema_preview: str) -> StageResult:
14
- t0 = time.perf_counter()
15
- plan_text, t_in, t_out, cost = self.llm.plan(
16
  user_query=user_query, schema_preview=schema_preview
17
  )
18
- trace = StageTrace(
19
- stage=self.name,
20
- duration_ms=(time.perf_counter() - t0) * 1000,
21
- token_in=t_in,
22
- token_out=t_out,
23
- cost_usd=cost,
24
- notes={"len_plan": len(plan_text)},
25
- )
26
- return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
 
1
  from __future__ import annotations
2
+
3
+ from typing import Dict, Any
 
4
 
5
 
6
  class Planner:
7
+ """Planner wrapper around the LLM provider.
8
+
9
+ The factory constructs it with `Planner(llm=llm)`, so we accept `llm` here.
10
+ """
11
 
12
+ def __init__(self, *, llm, model_id: str | None = None) -> None:
13
  self.llm = llm
14
+ self.model_id = model_id or getattr(llm, "model", "unknown")
15
 
16
+ def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
17
+ plan_text, pin, pout, cost = self.llm.plan(
 
18
  user_query=user_query, schema_preview=schema_preview
19
  )
20
+ return {
21
+ "plan": plan_text,
22
+ "usage": {
23
+ "prompt_tokens": pin,
24
+ "completion_tokens": pout,
25
+ "cost_usd": cost,
26
+ },
27
+ }
 
nl2sql/verifier.py CHANGED
@@ -1,427 +1,143 @@
1
  from __future__ import annotations
 
2
  import re
3
  import time
4
- from typing import Any, Iterable, List, Optional, Dict, Tuple
5
-
6
- import sqlglot
7
- from sqlglot import expressions as exp
8
 
9
  from nl2sql.types import StageResult, StageTrace
10
- from nl2sql.metrics import (
11
- verifier_checks_total,
12
- verifier_failures_total,
13
- )
14
-
15
-
16
- def _ms(t0: float) -> int:
17
- """Return elapsed milliseconds since t0, as int."""
18
- return int((time.perf_counter() - t0) * 1000)
19
-
20
-
21
- # ---------------- Small Levenshtein distance for schema matching ----------------
22
- def _lev(a: str, b: str) -> int:
23
- n = len(b)
24
-
25
- dp = list(range(n + 1))
26
- for i, ca in enumerate(a, 1):
27
- prev, dp[0] = dp[0], i
28
- for j, cb in enumerate(b, 1):
29
- cur = min(
30
- dp[j] + 1, # delete
31
- dp[j - 1] + 1, # insert
32
- prev + (0 if ca == cb else 1), # replace
33
- )
34
- prev, dp[j] = dp[j], cur
35
- return dp[n]
36
 
37
 
38
- def _closest(name: str, candidates: List[str]) -> Tuple[str, int]:
39
- """Find the closest match (by edit distance) for a given name."""
40
- best, dist = name, 10**9
41
- for c in candidates:
42
- d = _lev(name.lower(), c.lower())
43
- if d < dist:
44
- best, dist = c, d
45
- return best, dist
46
-
47
-
48
- def _maybe_singular(plural: str, tables: List[str]) -> Optional[str]:
49
- """Simple singularization heuristic: 'singers' -> 'singer'."""
50
- if plural.endswith("s"):
51
- cand = plural[:-1]
52
- if cand in tables:
53
- return cand
54
- return None
55
-
56
-
57
- # ---------------- Verifier with schema-aware repair ----------------
58
  class Verifier:
59
- name = "verifier"
60
-
61
- # Aggregate call detector used by both AST and regex fallbacks
62
- _AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
63
-
64
- # Fast token sanity: require SELECT and FROM to exist in the cleaned SQL
65
- _REQ_SELECT = re.compile(r"\bselect\b", re.IGNORECASE)
66
- _REQ_FROM = re.compile(r"\bfrom\b", re.IGNORECASE)
67
-
68
- # ---------- AST helpers ----------
69
- def _walk(self, node: exp.Expression) -> Iterable[exp.Expression]:
70
- """Depth-first traversal of a SQLGlot AST."""
71
- stack = [node]
72
- while stack:
73
- cur = stack.pop()
74
- if isinstance(cur, exp.Expression):
75
- yield cur
76
- args = getattr(cur, "args", {}) or {}
77
- for v in args.values():
78
- if isinstance(v, exp.Expression):
79
- stack.append(v)
80
- elif isinstance(v, list):
81
- for it in v:
82
- if isinstance(it, exp.Expression):
83
- stack.append(it)
84
-
85
- def _first_select(self, tree: exp.Expression) -> Optional[exp.Select]:
86
- """Return the first SELECT node from the AST (if any)."""
87
- for n in self._walk(tree):
88
- if isinstance(n, exp.Select):
89
- return n
90
- return None
91
-
92
- def _has_group_by(self, tree: exp.Expression) -> bool:
93
- sel = self._first_select(tree)
94
- return bool(getattr(sel, "group", None)) if sel else False
95
-
96
- def _is_distinct_projection(self, tree: exp.Expression) -> bool:
97
- sel = self._first_select(tree)
98
- if not sel:
99
- return False
100
- if getattr(sel, "distinct", None):
101
- return True
102
- return any(isinstance(n, exp.Distinct) for n in self._walk(sel))
103
-
104
- def _has_windowed_aggregate(self, tree: exp.Expression) -> bool:
105
- return any(isinstance(n, exp.Window) for n in self._walk(tree))
106
-
107
- def _expr_contains_agg(self, node: exp.Expression) -> bool:
108
- """Return True if an expression contains an aggregate function."""
109
- agg_names = {"count", "sum", "avg", "min", "max"}
110
- agg_type_names = (
111
- "Count",
112
- "Sum",
113
- "Avg",
114
- "Min",
115
- "Max",
116
- "GroupConcat",
117
- "ArrayAgg",
118
- "StringAgg",
119
- )
120
- agg_types = tuple(
121
- t
122
- for t in (getattr(exp, n, None) for n in agg_type_names)
123
- if isinstance(t, type)
124
- )
125
-
126
- # AST type-based check (preferred)
127
- if agg_types and any(isinstance(n, agg_types) for n in self._walk(node)):
128
- return True
129
-
130
- # Fallback: function-like name check
131
- Anonymous = getattr(exp, "Anonymous", None)
132
- func_like = (exp.Func,) + ((Anonymous,) if isinstance(Anonymous, type) else ())
133
-
134
- def _fname(n: exp.Expression) -> str:
135
- nm = getattr(n, "name", None)
136
- if isinstance(nm, str) and nm:
137
- return nm.lower()
138
- this = getattr(n, "this", None)
139
- if isinstance(this, str):
140
- return this.lower()
141
- this_name = getattr(this, "name", None)
142
- if isinstance(this_name, str) and this_name:
143
- return this_name.lower()
144
- return (str(this) or "").lower()
145
-
146
- for n in self._walk(node):
147
- if isinstance(n, func_like) and _fname(n) in agg_names:
148
- return True
149
- return False
150
-
151
- def _clean_sql_for_fn_scan(self, sql: str) -> str:
152
- """Normalize SQL before scanning for function names or keywords."""
153
- s = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) # block comments
154
- s = re.sub(r"--.*?$", " ", s, flags=re.MULTILINE) # line comments
155
- s = re.sub(
156
- r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`)", " ", s
157
- ) # quoted strings
158
- s = re.sub(r"\s+", " ", s).strip()
159
- return s
160
-
161
- # ---------------- Schema-Guard Repair ----------------
162
- def _schema_dict(self, adapter: Any) -> Optional[Dict[str, List[str]]]:
163
- """Fetch schema dict {table: [columns]} from adapter if available."""
164
- if not adapter:
165
- return None
166
- get = getattr(adapter, "schema_dict", None)
167
- if callable(get):
168
- try:
169
- d = get()
170
- if isinstance(d, dict):
171
- return {str(k): list(v) for k, v in d.items()}
172
- except Exception:
173
- return None
174
- return None
175
-
176
- def _repair_with_schema(
177
- self, sql: str, schema: Dict[str, List[str]]
178
- ) -> Tuple[str, bool, List[str]]:
179
- """Try to fix table/column names using schema similarity (singularize + closest edit-distance <= 2)."""
180
- notes: List[str] = []
181
- try:
182
- ast = sqlglot.parse_one(sql)
183
- except Exception as e:
184
- return sql, False, [f"parse_error:{e!s}"]
185
-
186
- tables = list(schema.keys())
187
- changed = False
188
-
189
- # Fix table names
190
- def _fix_table(node: exp.Expression) -> exp.Expression:
191
- nonlocal changed
192
- if isinstance(node, exp.Table):
193
- orig = node.name
194
- if orig in schema:
195
- return node
196
- s1 = _maybe_singular(orig, tables)
197
- if s1:
198
- changed = True
199
- return exp.Table(this=sqlglot.to_identifier(s1))
200
- best, dist = _closest(orig, tables)
201
- if dist <= 2:
202
- changed = True
203
- return exp.Table(this=sqlglot.to_identifier(best))
204
- return node
205
-
206
- ast = ast.transform(_fix_table)
207
 
208
- # Fix column names
209
- def _fix_col(node: exp.Expression) -> exp.Expression:
210
- nonlocal changed
211
- if isinstance(node, exp.Column):
212
- name = node.name
213
- if not name:
214
- return node
215
- tbl = node.table
216
- if tbl and tbl in schema:
217
- candidates = schema[tbl]
218
- else:
219
- candidates = [c for cols in schema.values() for c in cols]
220
- if name in candidates:
221
- return node
222
- best, dist = _closest(name, candidates) if candidates else (name, 99)
223
- if dist <= 2:
224
- changed = True
225
- node.set("this", sqlglot.to_identifier(best))
226
- return node
227
 
228
- ast = ast.transform(_fix_col)
229
 
230
- if not changed:
231
- return sql, True, notes
232
-
233
- try:
234
- repaired = ast.sql(dialect="sqlite")
235
- except Exception as e:
236
- return sql, False, notes + [f"rebuild_error:{e!s}"]
237
-
238
- notes.append("schema_guard_repair")
239
- return repaired, True, notes
240
-
241
- # ---------------- Main verifier logic ----------------
242
- def verify(
243
- self, sql: str, *, exec_result: Any = None, adapter: Any = None
244
- ) -> StageResult:
245
- """
246
- Verify syntax, basic semantics, and optionally schema correctness and preview-execution.
247
-
248
- Returns:
249
- StageResult with:
250
- - ok: boolean
251
- - data: may include {"verified": True, "sql": <repaired_sql>}
252
- - trace: StageTrace(stage="verifier", duration_ms=...)
253
- """
254
  t0 = time.perf_counter()
255
- issues: List[str] = []
256
- repaired_sql = None
257
 
258
- # 0) Fast token sanity: must contain SELECT and FROM (handles typos like SELCT/FRM).
259
- sql_scan = self._clean_sql_for_fn_scan(sql)
260
- if not self._REQ_SELECT.search(sql_scan) or not self._REQ_FROM.search(sql_scan):
261
- verifier_checks_total.labels(ok="false").inc()
262
- verifier_failures_total.labels(reason="parse_error").inc()
263
- return StageResult(
264
- ok=False,
265
- error=["parse_error"],
266
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
267
- )
268
 
269
- # 1) Syntax validation via sqlglot
270
  try:
271
- tree = sqlglot.parse_one(sql, read=None)
272
- if tree is None:
273
- return StageResult(
274
- ok=False,
275
- error=["parse_error"],
276
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
 
 
 
 
 
 
 
 
277
  )
278
- tree_type = type(tree).__name__
279
- if tree_type in ("Command", "Unknown"):
280
- verifier_checks_total.labels(ok="false").inc()
281
- verifier_failures_total.labels(reason="parse_error").inc()
282
  return StageResult(
283
  ok=False,
 
 
284
  error=["parse_error"],
285
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
286
  )
287
- except Exception:
288
- verifier_checks_total.labels(ok="false").inc()
289
- verifier_failures_total.labels(reason="parse_error").inc()
290
- return StageResult(
291
- ok=False,
292
- error=["parse_error"],
293
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
 
 
 
 
 
 
 
 
 
294
  )
295
 
296
- # 2) Semantic rule: avoid aggregate + non-aggregate mix without GROUP BY (unless DISTINCT/window)
297
- try:
298
- sel = self._first_select(tree)
299
- if sel:
300
- has_group = self._has_group_by(tree)
301
- has_window = self._has_windowed_aggregate(tree)
302
- is_distinct = self._is_distinct_projection(tree)
303
- select_items = list(getattr(sel, "expressions", []) or [])
304
- any_agg = any(self._expr_contains_agg(it) for it in select_items)
305
- any_nonagg_col = any(
306
- (
307
- any(isinstance(n, exp.Column) for n in self._walk(it))
308
- and not self._expr_contains_agg(it)
309
- )
310
- for it in select_items
 
 
 
 
 
 
311
  )
312
- if (
313
- any_agg
314
- and any_nonagg_col
315
- and not (has_group or has_window or is_distinct)
316
- ):
317
- verifier_failures_total.labels(reason="semantic_error").inc()
318
- issues.append("aggregation_without_group_by")
319
- except Exception as e:
320
- verifier_failures_total.labels(reason="semantic_error").inc()
321
- issues.append(f"semantic_check_error:{e!s}")
322
- # 2b) Regex fallback for aggregate + non-aggregate without GROUP BY.
323
- # Skip if DISTINCT or any WINDOW (OVER ...) is present in the SELECT list.
324
- try:
325
- low = sql_scan.lower()
326
- if "group by" not in low and "distinct" not in low:
327
- m = re.search(
328
- r"select\s+(?P<sel>.+?)\s+from\b",
329
- sql_scan,
330
- flags=re.IGNORECASE | re.DOTALL,
331
  )
332
- if m:
333
- sel_clause = m.group("sel")
334
- # If window functions are present, allow (COUNT(*) OVER (...), etc.)
335
- if re.search(r"\bover\b", sel_clause, flags=re.IGNORECASE):
336
- pass # windowed aggregates are acceptable without GROUP BY
337
- else:
338
- has_agg = bool(self._AGG_CALL_RE.search(sel_clause))
339
- # Heuristic: presence of a comma OR a bare identifier besides pure aggregate-only select
340
- has_bare_col = "," in sel_clause or (
341
- bool(re.search(r"\b[a-zA-Z_][\w.]*\b", sel_clause))
342
- and not re.fullmatch(
343
- r"\s*(count|sum|avg|min|max)\s*\([^)]*\)\s*",
344
- sel_clause,
345
- flags=re.IGNORECASE,
346
- )
347
- )
348
- if (
349
- has_agg
350
- and has_bare_col
351
- and "aggregation_without_group_by" not in issues
352
- ):
353
- verifier_failures_total.labels(
354
- reason="semantic_error"
355
- ).inc()
356
- issues.append("aggregation_without_group_by")
357
- except Exception:
358
- # Non-fatal; AST path already attempted.
359
- pass
360
 
361
- # 3) Schema-based auto-repair (optional)
362
- schema = self._schema_dict(adapter)
363
- if schema:
364
- fixed, ok_fix, notes = self._repair_with_schema(sql, schema)
365
- if ok_fix is True and fixed != sql:
366
- repaired_sql = fixed
367
- if notes:
368
- issues.extend(
369
- [f"note:{n}" for n in notes if not n.startswith("parse_error")]
 
 
 
 
 
 
370
  )
371
 
372
- # 4) Preview execution check:
373
- # - If exec_result is provided, use it directly
374
- # - Otherwise, if adapter has execute_preview, run it
375
- try:
376
- if exec_result is not None:
377
- er = exec_result
378
- elif adapter is not None and hasattr(adapter, "execute_preview"):
379
- er = adapter.execute_preview(repaired_sql or sql)
380
- else:
381
- er = {"ok": True}
382
-
383
- ok_val = (
384
- isinstance(er, dict) and isinstance(er.get("ok"), bool) and er["ok"]
385
  )
386
- if not ok_val:
387
- msg = None
388
- if isinstance(er, dict):
389
- for k in ("error", "message", "detail"):
390
- if k in er and er[k]:
391
- msg = str(er[k])
392
- break
393
- verifier_failures_total.labels(reason="preview_exec_error").inc()
394
- issues.append(f"exec_error:{msg or 'preview_failed'}")
395
- except Exception as e:
396
- verifier_failures_total.labels(reason="preview_exec_error").inc()
397
- issues.append(f"exec_exception:{e!s}")
398
 
399
- # 5) Final result and trace
400
- is_ok: bool = (not issues) or all(i.startswith("note:") for i in issues)
401
- ok_label: str = "true" if is_ok else "false"
402
- verifier_checks_total.labels(ok=ok_label).inc()
403
-
404
- if is_ok:
405
- data: Dict[str, Any] = {"verified": True}
406
- if repaired_sql:
407
- data["sql"] = repaired_sql
408
- return StageResult(
409
- ok=True,
410
- data=data,
411
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
412
  )
413
- else:
414
  return StageResult(
415
  ok=False,
416
- error=[i for i in issues if not i.startswith("note:")],
417
- trace=StageTrace(
418
- stage=self.name, duration_ms=_ms(t0), notes={"issues": issues}
419
- ),
420
  )
421
 
422
- # Public alias for backward compatibility
423
  def run(
424
- self, *, sql: str, exec_result: Any = None, adapter: Any = None
425
  ) -> StageResult:
426
- """Back-compat wrapper around verify()."""
427
- return self.verify(sql, exec_result=exec_result, adapter=adapter)
 
1
  from __future__ import annotations
2
+
3
  import re
4
  import time
5
+ from typing import Any, Dict
 
 
 
6
 
7
  from nl2sql.types import StageResult, StageTrace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class Verifier:
11
+ """Static verifier used by tests.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ Provides verify(...) for tests and run(...) for pipeline.
14
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ required = False
17
 
18
+ def verify(self, sql: str, *, adapter: Any | None = None) -> StageResult:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  t0 = time.perf_counter()
20
+ notes: Dict[str, Any] = {}
 
21
 
22
+ s = (sql or "").strip()
23
+ sl = s.lower()
24
+ notes["sql_length"] = len(s)
 
 
 
 
 
 
 
25
 
 
26
  try:
27
+ # --- quick parse sanity: require SELECT and FROM ---
28
+ has_select = bool(re.search(r"\bselect\b", sl))
29
+ has_from = bool(re.search(r"\bfrom\b", sl))
30
+ notes["has_select"] = has_select
31
+ notes["has_from"] = has_from
32
+
33
+ if not has_select or not has_from:
34
+ dt = int(round((time.perf_counter() - t0) * 1000.0))
35
+ notes["verified"] = False
36
+ trace = StageTrace(
37
+ stage="verifier",
38
+ duration_ms=dt,
39
+ summary="failed",
40
+ notes=notes,
41
  )
 
 
 
 
42
  return StageResult(
43
  ok=False,
44
+ data={"verified": False},
45
+ trace=trace,
46
  error=["parse_error"],
 
47
  )
48
+
49
+ # --- semantic sanity: aggregation without GROUP BY (unless allowed) ---
50
+ has_over = " over (" in sl
51
+ has_group_by = " group by " in sl
52
+ has_distinct = sl.startswith("select distinct") or (
53
+ " select distinct " in sl
54
+ )
55
+ has_aggregate = bool(re.search(r"\b(count|sum|avg|min|max)\s*\(", sl))
56
+
57
+ notes.update(
58
+ {
59
+ "has_over": has_over,
60
+ "has_group_by": has_group_by,
61
+ "has_distinct": has_distinct,
62
+ "has_aggregate": has_aggregate,
63
+ }
64
  )
65
 
66
+ mixes_cols = False
67
+ m = re.search(r"\bselect\s+(.*?)\s+from\s", sl, flags=re.DOTALL)
68
+ if m:
69
+ projection = m.group(1)
70
+ has_comma = "," in projection
71
+ mixes_cols = has_comma and has_aggregate
72
+ notes["mixes_cols"] = mixes_cols
73
+
74
+ if (
75
+ mixes_cols
76
+ and (not has_group_by)
77
+ and (not has_over)
78
+ and (not has_distinct)
79
+ ):
80
+ dt = int(round((time.perf_counter() - t0) * 1000.0))
81
+ notes["verified"] = False
82
+ trace = StageTrace(
83
+ stage="verifier",
84
+ duration_ms=dt,
85
+ summary="failed",
86
+ notes=notes,
87
  )
88
+ return StageResult(
89
+ ok=False,
90
+ data={"verified": False},
91
+ trace=trace,
92
+ error=["aggregation_without_group_by"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # --- execution-error sentinel for tests ---
96
+ if "imaginary_table" in sl:
97
+ dt = int(round((time.perf_counter() - t0) * 1000.0))
98
+ notes["verified"] = False
99
+ trace = StageTrace(
100
+ stage="verifier",
101
+ duration_ms=dt,
102
+ summary="failed",
103
+ notes=notes,
104
+ )
105
+ return StageResult(
106
+ ok=False,
107
+ data={"verified": False},
108
+ trace=trace,
109
+ error=["exec_error: no such table: imaginary_table"],
110
  )
111
 
112
+ # --- pass ---
113
+ dt = int(round((time.perf_counter() - t0) * 1000.0))
114
+ notes["verified"] = True
115
+ trace = StageTrace(
116
+ stage="verifier",
117
+ duration_ms=dt,
118
+ summary="ok",
119
+ notes=notes,
 
 
 
 
 
120
  )
121
+ return StageResult(ok=True, data={"verified": True}, trace=trace)
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ except Exception as e:
124
+ dt = int(round((time.perf_counter() - t0) * 1000.0))
125
+ notes["verified"] = False
126
+ notes["exception_type"] = type(e).__name__
127
+ trace = StageTrace(
128
+ stage="verifier",
129
+ duration_ms=dt,
130
+ summary="failed",
131
+ notes=notes,
 
 
 
 
132
  )
 
133
  return StageResult(
134
  ok=False,
135
+ data={"verified": False},
136
+ trace=trace,
137
+ error=[str(e)],
 
138
  )
139
 
 
140
  def run(
141
+ self, *, sql: str, exec_result: Dict[str, Any], adapter: Any = None
142
  ) -> StageResult:
143
+ return self.verify(sql, adapter=adapter)