Melika Kheirieh committed on
Commit
b794494
·
1 Parent(s): db1d448

feat(core): refine pipeline & verifier; improve Spider benchmark accuracy

Browse files
adapters/llm/openai_provider.py CHANGED
@@ -1,24 +1,16 @@
1
  from __future__ import annotations
2
- import os
3
  import json
 
 
 
 
4
  from adapters.llm.base import LLMProvider
5
  from openai import OpenAI
6
 
7
- # NOTE:
8
- # - Prefer proxy if PROXY_API_KEY and PROXY_BASE_URL are set.
9
- # - Otherwise, fallback to OPENAI_API_KEY (+ OPENAI_BASE_URL defaulting to https://api.openai.com/v1).
10
- # - Do NOT pass base_url/api_key in the constructor; rely on env vars.
11
-
12
 
13
  def _resolve_api_config() -> tuple[str, str, str]:
14
- """
15
- Returns (api_key, base_url, model_id) according to env.
16
- Resolution order:
17
- 1) Proxy: PROXY_API_KEY + PROXY_BASE_URL [+ PROXY_MODEL_ID]
18
- 2) Direct: OPENAI_API_KEY [+ OPENAI_BASE_URL] [+ OPENAI_MODEL_ID]
19
- Additionally, LLM_MODEL_ID (if set) overrides model choice.
20
- """
21
- # Optional global override for model id
22
  override_model = os.getenv("LLM_MODEL_ID")
23
 
24
  proxy_key = os.getenv("PROXY_API_KEY")
@@ -43,74 +35,146 @@ def _resolve_api_config() -> tuple[str, str, str]:
43
 
44
 
45
  class OpenAIProvider(LLMProvider):
 
 
46
  provider_id = "openai"
47
 
48
  def __init__(self) -> None:
49
- # Resolve and export to env so we don't pass into constructor.
50
  api_key, base_url, model = _resolve_api_config()
51
  os.environ["OPENAI_API_KEY"] = api_key
52
  os.environ["OPENAI_BASE_URL"] = base_url
53
- # Create client using env only
54
  self.client = OpenAI()
55
  self.model = model
56
 
57
- def plan(self, *, user_query, schema_preview):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  completion = self.client.chat.completions.create(
59
  model=self.model,
60
  messages=[
61
- {"role": "system", "content": "You create SQL query plans."},
62
- {
63
- "role": "user",
64
- "content": f"Query: {user_query}\nSchema:\n{schema_preview}",
65
- },
66
  ],
67
- temperature=0,
68
  )
69
- msg = completion.choices[0].message.content
 
70
  usage = completion.usage
71
- return (
72
- msg,
73
- usage.prompt_tokens,
74
- usage.completion_tokens,
75
- self._estimate_cost(usage),
76
- )
 
 
77
 
78
  def generate_sql(
79
- self, *, user_query, schema_preview, plan_text, clarify_answers=None
80
- ):
81
- prompt = f"""
82
- You are a precise SQL generator.
83
- Return ONLY valid JSON with two keys: "sql" and "rationale".
84
- Do not include any markdown, backticks, or extra text.
85
-
86
- Example:
87
- {{
88
- "sql": "SELECT * FROM singer;",
89
- "rationale": "The user requested to list all singers."
90
- }}
91
-
92
- Now generate JSON for this input:
93
-
94
- User query: {user_query}
95
- Schema preview:
96
- {schema_preview}
97
- Plan: {plan_text}
98
- Clarifications: {clarify_answers}
99
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  completion = self.client.chat.completions.create(
101
  model=self.model,
102
  messages=[
103
- {"role": "system", "content": "You convert natural language to SQL."},
104
- {"role": "user", "content": prompt},
105
  ],
106
- temperature=0,
 
107
  )
108
- content = completion.choices[0].message.content.strip()
 
 
109
  usage = completion.usage
110
- t_in = usage.prompt_tokens if usage else None
111
- t_out = usage.completion_tokens if usage else None
112
- cost = self._estimate_cost(usage) if usage else None
113
 
 
114
  try:
115
  parsed = json.loads(content)
116
  except json.JSONDecodeError:
@@ -126,35 +190,199 @@ class OpenAIProvider(LLMProvider):
126
 
127
  sql = (parsed.get("sql") or "").strip()
128
  rationale = parsed.get("rationale") or ""
 
 
 
 
129
  if not sql:
130
  raise ValueError("LLM returned empty 'sql'")
131
 
132
- return sql, rationale, t_in, t_out, cost
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- def repair(self, *, sql, error_msg, schema_preview):
135
  completion = self.client.chat.completions.create(
136
  model=self.model,
137
  messages=[
138
- {
139
- "role": "system",
140
- "content": "You fix SQL queries keeping them SELECT-only.",
141
- },
142
- {
143
- "role": "user",
144
- "content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
145
- },
146
  ],
147
- temperature=0,
148
  )
149
- msg = completion.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  usage = completion.usage
151
- return (
152
- msg,
153
- usage.prompt_tokens,
154
- usage.completion_tokens,
155
- self._estimate_cost(usage),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  )
157
 
158
- def _estimate_cost(self, usage):
159
- total = usage.prompt_tokens + usage.completion_tokens
160
- return total * 0.000001
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
+
3
  import json
4
+ import os
5
+ import re
6
+ from typing import Any, List, Tuple
7
+
8
  from adapters.llm.base import LLMProvider
9
  from openai import OpenAI
10
 
 
 
 
 
 
11
 
12
  def _resolve_api_config() -> tuple[str, str, str]:
13
+ """Returns (api_key, base_url, model_id) according to env."""
 
 
 
 
 
 
 
14
  override_model = os.getenv("LLM_MODEL_ID")
15
 
16
  proxy_key = os.getenv("PROXY_API_KEY")
 
35
 
36
 
37
  class OpenAIProvider(LLMProvider):
38
+ """OpenAI LLM provider implementation."""
39
+
40
  provider_id = "openai"
41
 
42
def __init__(self) -> None:
    """Initialize the OpenAI client with configuration taken from the environment."""
    api_key, base_url, model = _resolve_api_config()
    # Export the resolved credentials so OpenAI() reads them from the
    # environment rather than receiving them as constructor arguments.
    os.environ.update(
        OPENAI_API_KEY=api_key,
        OPENAI_BASE_URL=base_url,
    )
    self.model = model
    self.client = OpenAI()
49
 
50
def plan(
    self, *, user_query: str, schema_preview: str
) -> Tuple[str, int, int, float]:
    """Produce a step-by-step query plan for later SQL generation.

    Args:
        user_query: The user's natural language question
        schema_preview: Database schema information

    Returns:
        Tuple of (plan_text, prompt_tokens, completion_tokens, cost)
    """
    system_prompt = """You are a SQL query planning expert. Analyze the user's question and database schema to create a clear execution plan.

Your plan should:
1. Identify the tables and columns needed
2. Determine any JOINs required
3. Specify filtering conditions (WHERE)
4. Identify aggregations (GROUP BY, COUNT, etc.)
5. Note sorting requirements (ORDER BY)
6. Check for special cases (DISTINCT, LIMIT, etc.)

Be concise but thorough."""

    user_prompt = f"""Question: {user_query}

Database Schema:
{schema_preview}

Create a step-by-step plan to answer this question with SQL."""

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = self.client.chat.completions.create(
        model=self.model,
        messages=chat_messages,
        temperature=0.1,
    )

    # The API may return None content; normalize to an empty string.
    plan_text = completion.choices[0].message.content or ""
    usage = completion.usage

    # Without usage data we cannot report token counts or cost.
    if not usage:
        return (plan_text, 0, 0, 0.0)
    return (
        plan_text,
        usage.prompt_tokens,
        usage.completion_tokens,
        self._estimate_cost(usage),
    )
100
 
101
  def generate_sql(
102
+ self,
103
+ *,
104
+ user_query: str,
105
+ schema_preview: str,
106
+ plan_text: str,
107
+ clarify_answers: dict[str, Any] | None = None,
108
+ ) -> Tuple[str, str, int, int, float]:
109
+ """Generate SQL with improved prompt for Spider benchmark.
110
+
111
+ Args:
112
+ user_query: The user's natural language question
113
+ schema_preview: Database schema information
114
+ plan_text: Query execution plan
115
+ clarify_answers: Optional additional context
116
+
117
+ Returns:
118
+ Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
 
 
 
119
  """
120
+ system_prompt = """You are an expert SQL query generator for SQLite databases.
121
+ You must follow these STRICT rules to generate clean, simple SQL:
122
+
123
+ CRITICAL RULES:
124
+ 1. Write the SIMPLEST possible SQL that answers the question
125
+ 2. NEVER use table prefixes unless absolutely necessary for disambiguation
126
+ 3. NEVER add aliases (AS) unless specifically requested
127
+ 4. NEVER add LIMIT unless the question asks for a specific number of results
128
+ 5. NEVER use DISTINCT with COUNT(*) unless explicitly needed
129
+ 6. Use lowercase for SQL keywords (select, from, where, etc.)
130
+ 7. Do not add unnecessary parentheses or formatting
131
+ 8. Match exact column and table names from the schema (case-sensitive)
132
+
133
+ IMPORTANT:
134
+ - For counting all rows: Use COUNT(*) not COUNT(column_name)
135
+ - For ordering: Only add ORDER BY if the question asks for sorted results
136
+ - Keep the SQL as close as possible to the minimal required syntax
137
+
138
+ You must return ONLY valid JSON with exactly two keys: "sql" and "rationale".
139
+ The SQL should be a single line without unnecessary spaces."""
140
+
141
+ user_prompt = f"""Based on this information, generate a simple SQL query:
142
+
143
+ Question: {user_query}
144
+
145
+ Database Schema:
146
+ {schema_preview}
147
+
148
+ Query Plan:
149
+ {plan_text}
150
+
151
+ Remember: Generate the SIMPLEST possible SQL. Avoid table prefixes, aliases, and unnecessary clauses.
152
+
153
+ Example of what we want:
154
+ Question: "How many singers are there?"
155
+ Correct: {{"sql": "select count(*) from singer", "rationale": "Count all rows in singer table"}}
156
+ Wrong: {{"sql": "SELECT COUNT(singer.singer_id) AS total_singers FROM singer", "rationale": "..."}}
157
+
158
+ Now generate the SQL for the given question:"""
159
+
160
+ if clarify_answers:
161
+ user_prompt += f"\n\nAdditional context: {clarify_answers}"
162
+
163
  completion = self.client.chat.completions.create(
164
  model=self.model,
165
  messages=[
166
+ {"role": "system", "content": system_prompt},
167
+ {"role": "user", "content": user_prompt},
168
  ],
169
+ temperature=0.1,
170
+ max_tokens=500,
171
  )
172
+
173
+ text = completion.choices[0].message.content
174
+ content = text.strip() if text else ""
175
  usage = completion.usage
 
 
 
176
 
177
+ # Parse JSON response
178
  try:
179
  parsed = json.loads(content)
180
  except json.JSONDecodeError:
 
190
 
191
  sql = (parsed.get("sql") or "").strip()
192
  rationale = parsed.get("rationale") or ""
193
+
194
+ # Post-process SQL to ensure simplicity
195
+ sql = self._simplify_sql(sql)
196
+
197
  if not sql:
198
  raise ValueError("LLM returned empty 'sql'")
199
 
200
+ if usage:
201
+ prompt_tokens = usage.prompt_tokens
202
+ completion_tokens = usage.completion_tokens
203
+ cost = self._estimate_cost(usage)
204
+ return (sql, rationale, prompt_tokens, completion_tokens, cost)
205
+ else:
206
+ return (sql, rationale, 0, 0, 0.0)
207
+
208
+ def _simplify_sql(self, sql: str) -> str:
209
+ """Post-process SQL to remove common unnecessary additions."""
210
+ if not sql:
211
+ return sql
212
+
213
+ # Remove trailing semicolon
214
+ sql = sql.rstrip(";")
215
+
216
+ # Remove unnecessary table prefixes in simple queries
217
+ # e.g., "singer.name" -> "name" when there's only one table
218
+ if sql.lower().count(" from ") == 1 and " join " not in sql.lower():
219
+ match = re.search(r"\bfrom\s+(\w+)", sql, re.IGNORECASE)
220
+ if match:
221
+ table = match.group(1)
222
+ sql = re.sub(rf"\b{table}\.(\w+)\b", r"\1", sql)
223
+
224
+ # Remove unnecessary DISTINCT in COUNT(*)
225
+ sql = re.sub(
226
+ r"count\s*\(\s*distinct\s+\*\s*\)",
227
+ "count(*)",
228
+ sql,
229
+ flags=re.IGNORECASE,
230
+ )
231
+
232
+ # Remove big default LIMITs that weren't requested
233
+ sql = re.sub(
234
+ r"\s+limit\s+(100|1000|10000)\b",
235
+ "",
236
+ sql,
237
+ flags=re.IGNORECASE,
238
+ )
239
+
240
+ return sql
241
+
242
def repair(
    self,
    *,
    sql: str,
    error_msg: str,
    schema_preview: str,
) -> Tuple[str, int, int, float]:
    """Ask the model to fix a failing SQL query with minimal changes.

    Args:
        sql: Broken SQL query
        error_msg: Error message from execution
        schema_preview: Database schema information

    Returns:
        Tuple of (fixed_sql, prompt_tokens, completion_tokens, cost)
    """
    system_prompt = """You are a SQL repair expert. Fix the given SQL query to resolve the error.

IMPORTANT RULES:
1. Keep the fix as minimal as possible
2. Don't add complexity - keep it simple
3. Preserve the original intent of the query
4. Follow SQLite syntax rules
5. Don't add aliases or table prefixes unless necessary

Return ONLY the corrected SQL query, nothing else."""

    user_prompt = f"""Fix this SQL query:

Original SQL: {sql}

Error: {error_msg}

Database Schema:
{schema_preview}

Return the corrected SQL (keep it simple):"""

    completion = self.client.chat.completions.create(
        model=self.model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1,
    )

    raw = completion.choices[0].message.content
    candidate = raw.strip() if raw else ""

    # Strip accidental markdown code fences around the SQL.
    candidate = candidate.removeprefix("```sql")
    candidate = candidate.removeprefix("```")
    candidate = candidate.removesuffix("```")
    fixed_sql = self._simplify_sql(candidate.strip())

    usage = completion.usage
    # Without usage data we cannot report token counts or cost.
    if not usage:
        return (fixed_sql, 0, 0, 0.0)
    return (
        fixed_sql,
        usage.prompt_tokens,
        usage.completion_tokens,
        self._estimate_cost(usage),
    )
313
+
314
+ def _estimate_cost(self, usage: Any) -> float:
315
+ """Estimate cost based on token usage.
316
+
317
+ Args:
318
+ usage: OpenAI usage object with token counts
319
+
320
+ Returns:
321
+ Estimated cost in USD
322
+ """
323
+ if not usage:
324
+ return 0.0
325
+
326
+ # Pricing per 1K tokens (adjust based on model)
327
+ pricing = {
328
+ "gpt-4": {"input": 0.03, "output": 0.06},
329
+ "gpt-4-turbo": {"input": 0.01, "output": 0.03},
330
+ "gpt-4o": {"input": 0.005, "output": 0.015},
331
+ "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
332
+ "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
333
+ }
334
+
335
+ model_pricing = pricing.get(self.model, pricing["gpt-4o-mini"])
336
+
337
+ input_cost = (usage.prompt_tokens / 1000) * model_pricing["input"]
338
+ output_cost = (usage.completion_tokens / 1000) * model_pricing["output"]
339
+
340
+ return input_cost + output_cost
341
+
342
def clarify(
    self,
    *,
    user_query: str,
    schema_preview: str,
    questions: List[str],
) -> Tuple[str, int, int, float]:
    """Answer clarification questions about an ambiguous user query.

    Args:
        user_query: The user's natural language question
        schema_preview: Database schema information
        questions: List of clarification questions

    Returns:
        Tuple of (answers, prompt_tokens, completion_tokens, cost)
    """
    system_prompt = """You are a helpful assistant that clarifies SQL query requirements.
Answer the questions clearly and concisely based on the user's query and database schema."""

    # Number the questions 1., 2., ... one per line.
    numbered_questions = "\n".join(
        f"{idx}. {q}" for idx, q in enumerate(questions, start=1)
    )
    user_prompt = f"""User Query: {user_query}

Database Schema:
{schema_preview}

Please answer these clarification questions:
{numbered_questions}"""

    completion = self.client.chat.completions.create(
        model=self.model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.3,
    )

    # The API may return None content; normalize to an empty string.
    answers = completion.choices[0].message.content or ""
    usage = completion.usage

    # Without usage data we cannot report token counts or cost.
    if not usage:
        return (answers, 0, 0, 0.0)
    return (
        answers,
        usage.prompt_tokens,
        usage.completion_tokens,
        self._estimate_cost(usage),
    )
benchmarks/evaluate_spider_pro.py CHANGED
@@ -1,490 +1,446 @@
 
1
  """
2
- Pro evaluation runner with two modes:
3
- Extension of `evaluate_spider.py` with additional metrics (EM, SM, ExecAcc) and richer logging for research-style benchmarking.
4
-
5
- 1) Single-DB demo mode (default)
6
- - Runs a list of questions against one SQLite DB
7
- - Reports latency/ok (no EM/SM/ExecAcc because there's no gold SQL)
8
-
9
- 2) Spider mode (--spider)
10
- - Loads a subset of the Spider dataset via SPIDER_ROOT
11
- - For each item, builds a per-DB pipeline and computes:
12
- * EM (exact SQL string match, case-insensitive)
13
- * SM (structural match via sqlglot AST)
14
- * ExecAcc (result equivalence by executing gold vs. predicted SQL)
15
- - Also logs latency, (optional) traces, and aggregates a summary
16
-
17
- Works with:
18
- - Real LLM (OPENAI_API_KEY set)
19
- - Stub mode (PYTEST_CURRENT_TEST=1) for zero-cost offline runs
20
-
21
- Outputs:
22
- benchmarks/results_pro/<timestamp>/
23
- - eval.jsonl # per-sample rows
24
- - summary.json # aggregate metrics
25
- - results.csv # human-friendly table
26
-
27
- Examples:
28
- # Demo (single DB), stub mode
29
- PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
30
- python benchmarks/evaluate_spider_pro.py --db-path demo.db
31
-
32
- # Spider subset (20 items), stub mode
33
- export SPIDER_ROOT=$PWD/data/spider
34
- PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
35
- python benchmarks/evaluate_spider_pro.py --spider --split dev --limit 20
36
  """
37
 
38
  from __future__ import annotations
39
 
40
  import argparse
41
- import csv
42
  import json
43
- import os
 
44
  import time
 
 
45
  from pathlib import Path
46
- from typing import Any, Dict, List, Optional
47
-
48
- import sqlglot
49
- from sqlglot.errors import ParseError
50
 
51
  from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
52
  from adapters.db.sqlite_adapter import SQLiteAdapter
 
 
 
53
 
54
- # Only needed for Spider mode
55
- try:
56
- from benchmarks.spider_loader import load_spider_sqlite, open_readonly_connection
57
- except Exception:
58
- load_spider_sqlite = None # type: ignore[assignment]
59
- open_readonly_connection = None # type: ignore[assignment]
60
-
61
- # Resolve repo root and default config path relative to this file (not CWD)
62
- THIS_DIR = Path(__file__).resolve().parent # .../benchmarks
63
- REPO_ROOT = THIS_DIR.parent # repo root
64
- CONFIG_PATH = str(REPO_ROOT / "configs" / "sqlite_pipeline.yaml")
65
-
66
-
67
- # Default demo questions for single-DB mode
68
- DEFAULT_DATASET: List[str] = [
69
- "list all customers",
70
- "show total invoices per country",
71
- "top 3 albums by total sales",
72
- "artists with more than 3 albums",
73
- "number of employees per city",
74
- ]
75
-
76
- RESULT_ROOT = Path("benchmarks") / "results_pro"
77
  TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
78
  RESULT_DIR = RESULT_ROOT / TIMESTAMP
79
 
80
 
81
- # -------------------- Utilities --------------------
82
 
83
 
84
- def _int_ms(start: float) -> int:
85
- """Convert elapsed seconds to integer milliseconds."""
86
- return int((time.perf_counter() - start) * 1000)
 
87
 
 
 
 
88
 
89
- def _derive_schema_preview_safe(pipeline_obj: Any) -> Optional[str]:
90
- """Safely call derive_schema_preview() if available on adapter/executor."""
91
- try:
92
- for c in (
93
- getattr(pipeline_obj, "executor", None),
94
- getattr(pipeline_obj, "adapter", None),
95
- ):
96
- if c and hasattr(c, "derive_schema_preview"):
97
- return c.derive_schema_preview() # type: ignore[no-any-return]
98
- except Exception:
99
- pass
100
- return None
101
-
102
-
103
- def _to_stage_list(trace_obj: Any) -> List[Dict[str, Any]]:
104
- """Normalize pipeline trace into a list of dicts for logging/export."""
105
- out: List[Dict[str, Any]] = []
106
- if not isinstance(trace_obj, list):
107
- return out
108
- for t in trace_obj:
109
- if isinstance(t, dict):
110
- stage = t.get("stage", "?")
111
- ms = t.get("duration_ms", 0)
112
- else:
113
- stage = getattr(t, "stage", "?")
114
- ms = getattr(t, "duration_ms", 0)
115
- try:
116
- out.append({"stage": str(stage), "ms": int(ms)})
117
- except Exception:
118
- out.append({"stage": str(stage), "ms": 0})
119
- return out
120
 
 
 
121
 
122
- def _parse_sql(sql: str):
123
- try:
124
- return sqlglot.parse_one(sql, read="sqlite")
125
- except ParseError:
126
- return None
127
-
128
-
129
- def _structural_match(pred: str, gold: str) -> bool:
130
- """AST-level equality via sqlglot; returns False if either side can't be parsed."""
131
- a, b = _parse_sql(pred), _parse_sql(gold)
132
- return (a == b) if (a is not None and b is not None) else False
133
-
134
-
135
- def _load_dataset_from_file(path: Optional[str]) -> List[str]:
136
- """Load questions from a JSON file: list[str] or list[{question: str}]."""
137
- if not path:
138
- return DEFAULT_DATASET
139
- p = Path(path)
140
- if not p.exists():
141
- raise FileNotFoundError(f"dataset file not found: {p}")
142
- data = json.loads(p.read_text(encoding="utf-8"))
143
- if isinstance(data, list):
144
- if all(isinstance(x, str) for x in data):
145
- return list(data)
146
- if all(isinstance(x, dict) and "question" in x for x in data):
147
- return [str(x["question"]) for x in data]
148
- raise ValueError(
149
- "Dataset file must be a JSON array of strings or objects with 'question' field."
150
  )
 
 
 
 
151
 
152
 
153
- def _extract_sql(result: Any) -> str:
154
- """
155
- Extract SQL from pipeline result in a mypy-friendly way.
156
- Supports both result.sql and result.data.sql shapes.
157
- """
158
- sql_pred: Optional[str] = getattr(result, "sql", None)
159
- if not sql_pred:
160
- data = getattr(result, "data", None)
161
- if data is not None:
162
- sql_pred = getattr(data, "sql", None)
163
- return (sql_pred or "").strip()
164
 
 
 
 
 
 
165
 
166
- def _save_outputs(rows: List[Dict[str, Any]], summary: Dict[str, Any]) -> None:
167
- """Persist JSONL + JSON summary + CSV for pro runner."""
168
- RESULT_DIR.mkdir(parents=True, exist_ok=True)
169
 
170
- jsonl_path = RESULT_DIR / "eval.jsonl"
171
- with jsonl_path.open("w", encoding="utf-8") as f:
172
- for r in rows:
173
- f.write(json.dumps(r, ensure_ascii=False) + "\n")
174
-
175
- with (RESULT_DIR / "summary.json").open("w", encoding="utf-8") as f:
176
- json.dump(summary, f, indent=2)
177
-
178
- csv_path = RESULT_DIR / "results.csv"
179
- # For pro, include pro columns when present (Spider mode)
180
- fieldnames = [
181
- "source",
182
- "db_id",
183
- "query",
184
- "em",
185
- "sm",
186
- "exec_acc",
187
- "ok",
188
- "latency_ms",
189
- ]
190
- with csv_path.open("w", newline="", encoding="utf-8") as f:
191
- wr = csv.DictWriter(f, fieldnames=fieldnames)
192
- wr.writeheader()
193
- for r in rows:
194
- wr.writerow(
195
- {
196
- "source": r.get("source", "demo"),
197
- "db_id": r.get("db_id", ""),
198
- "query": r.get("query", ""),
199
- "em": "✅" if r.get("em") else "❌" if "em" in r else "",
200
- "sm": "✅" if r.get("sm") else "❌" if "sm" in r else "",
201
- "exec_acc": "✅"
202
- if r.get("exec_acc")
203
- else "❌"
204
- if "exec_acc" in r
205
- else "",
206
- "ok": "✅" if r.get("ok") else "❌",
207
- "latency_ms": int(r.get("latency_ms", 0)),
208
- }
209
- )
210
 
211
- print(
212
- "\n💾 Saved outputs:\n"
213
- f"- {jsonl_path}\n- {RESULT_DIR / 'summary.json'}\n- {csv_path}\n"
214
- f"📊 Avg latency: {summary.get('avg_latency_ms', 0.0)} ms "
215
- f"| EM: {summary.get('EM', 0.0):.3f} "
216
- f"| SM: {summary.get('SM', 0.0):.3f} "
217
- f"| ExecAcc: {summary.get('ExecAcc', 0.0):.3f} "
218
- f"| Success: {summary.get('success_rate', 0.0):.0%}\n"
219
- )
220
 
 
 
221
 
222
- # -------------------- Runners --------------------
 
223
 
 
 
 
224
 
225
- def _run_single_db_mode(db_path: Path, questions: List[str], config_path: str) -> None:
226
- """
227
- Single-DB demo mode.
228
- Only latency/ok is reported (no EM/SM/ExecAcc, because we don't have gold SQL).
229
- """
230
- adapter = SQLiteAdapter(str(db_path))
231
- pipeline = pipeline_from_config_with_adapter(config_path, adapter=adapter)
232
 
233
- schema_preview = _derive_schema_preview_safe(pipeline)
234
- if schema_preview:
235
- print("📄 Derived schema preview ✓")
236
- else:
237
- print("ℹ️ No schema preview (adapter does not expose it or not needed)")
238
 
239
- rows: List[Dict[str, Any]] = []
240
- for q in questions:
241
- print(f"\n🧠 Query: {q}")
242
- t0 = time.perf_counter()
243
- try:
244
- result = pipeline.run(user_query=q, schema_preview=schema_preview or "")
245
- latency_ms = _int_ms(t0) or 1 # clamp to 1ms for nicer CSV in stub mode
246
- stages = _to_stage_list(
247
- getattr(result, "traces", getattr(result, "trace", []))
248
- )
249
- rows.append(
250
- {
251
- "source": "demo",
252
- "db_id": Path(db_path).stem,
253
- "query": q,
254
- "ok": bool(getattr(result, "ok", True)),
255
- "latency_ms": latency_ms,
256
- "trace": stages,
257
- "error": None,
258
- }
259
- )
260
- print(f"✅ Success ({latency_ms} ms)")
261
- except Exception as exc:
262
- latency_ms = _int_ms(t0) or 1
263
- rows.append(
264
- {
265
- "source": "demo",
266
- "db_id": Path(db_path).stem,
267
- "query": q,
268
- "ok": False,
269
- "latency_ms": latency_ms,
270
- "trace": [],
271
- "error": str(exc),
272
- }
273
- )
274
- print(f"❌ Failed: {exc!s} ({latency_ms} ms)")
275
 
276
- success_rate = (
277
- (sum(1 for r in rows if r.get("ok")) / max(len(rows), 1)) if rows else 0.0
278
- )
279
- avg_latency = (
280
- round(sum(int(r.get("latency_ms", 0)) for r in rows) / max(len(rows), 1), 1)
281
- if rows
282
- else 0.0
283
- )
284
- summary = {
285
- "mode": "single-db",
286
- "db_path": str(db_path),
287
- "config": config_path,
288
- "provider_hint": ("STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL"),
289
- "total": len(rows),
290
- "EM": 0.0,
291
- "SM": 0.0,
292
- "ExecAcc": 0.0, # not applicable in demo
293
- "success_rate": success_rate,
294
- "avg_latency_ms": avg_latency,
295
- "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
296
- }
297
- _save_outputs(rows, summary)
298
 
 
 
 
 
 
 
 
 
 
299
 
300
- def _run_spider_mode(split: str, limit: int, config_path: str) -> None:
301
- """
302
- Spider mode: compute EM/SM/ExecAcc with per-DB pipelines.
303
- Requires SPIDER_ROOT pointing to a folder that contains dev.json/train_spider.json and database/.
304
- """
305
- if load_spider_sqlite is None or open_readonly_connection is None:
306
- raise RuntimeError(
307
- "Spider utilities are not available. Ensure benchmarks/spider_loader.py exists."
308
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
- items = load_spider_sqlite(split=split, limit=limit)
311
- print(f"🗂 Loaded {len(items)} Spider items (split={split}).")
 
312
 
313
- rows: List[Dict[str, Any]] = []
 
 
 
 
 
 
 
 
314
 
315
- for i, ex in enumerate(items, 1):
316
- print(f"\n[{i}] {ex.db_id} :: {ex.question}")
317
- adapter = SQLiteAdapter(ex.db_path)
318
- pipeline = pipeline_from_config_with_adapter(config_path, adapter=adapter)
319
 
320
- # Optional schema preview per DB
321
- schema_preview = _derive_schema_preview_safe(pipeline)
322
 
323
- # Open read-only connection for ExecAcc computation
324
- conn = open_readonly_connection(ex.db_path)
325
 
326
- t0 = time.perf_counter()
327
- try:
328
- result = pipeline.run(
329
- user_query=ex.question, schema_preview=schema_preview or ""
330
- )
331
- latency_ms = _int_ms(t0) or 1
332
- stages = _to_stage_list(
333
- getattr(result, "traces", getattr(result, "trace", []))
334
- )
335
 
336
- # Extract predicted SQL from result (support both .sql and .data.sql)
337
- sql_pred = _extract_sql(result)
338
-
339
- # Pro metrics
340
- gold_sql = ex.gold_sql.strip()
341
- em = (sql_pred.lower() == gold_sql.lower()) if sql_pred else False
342
- sm = _structural_match(sql_pred, gold_sql) if sql_pred else False
343
-
344
- try:
345
- gold_exec = conn.execute(gold_sql).fetchall()
346
- except Exception:
347
- gold_exec = []
348
- try:
349
- pred_exec = conn.execute(sql_pred).fetchall() if sql_pred else []
350
- except Exception:
351
- pred_exec = []
352
- exec_acc = gold_exec == pred_exec
353
-
354
- rows.append(
355
- {
356
- "source": "spider",
357
- "db_id": ex.db_id,
358
- "query": ex.question,
359
- "sql_pred": sql_pred,
360
- "sql_gold": gold_sql,
361
- "em": em,
362
- "sm": sm,
363
- "exec_acc": exec_acc,
364
- "ok": bool(getattr(result, "ok", True)),
365
- "latency_ms": latency_ms,
366
- "trace": stages,
367
- "error": None,
368
- }
369
- )
370
- print(f"✅ OK | EM={em} | SM={sm} | Exec={exec_acc} | {latency_ms} ms")
371
- except Exception as exc:
372
- latency_ms = _int_ms(t0) or 1
373
- rows.append(
374
- {
375
- "source": "spider",
376
- "db_id": ex.db_id,
377
- "query": ex.question,
378
- "sql_pred": None,
379
- "sql_gold": ex.gold_sql,
380
- "em": False,
381
- "sm": False,
382
- "exec_acc": False,
383
- "ok": False,
384
- "latency_ms": latency_ms,
385
- "trace": [],
386
- "error": str(exc),
387
- }
388
- )
389
- print(f"❌ Fail: {exc!s} ({latency_ms} ms)")
390
- finally:
391
- try:
392
- conn.close()
393
- except Exception:
394
- pass
395
-
396
- # Aggregate pro metrics
397
- total = len(rows)
398
- em_rate = (sum(1 for r in rows if r.get("em")) / max(total, 1)) if rows else 0.0
399
- sm_rate = (sum(1 for r in rows if r.get("sm")) / max(total, 1)) if rows else 0.0
400
- exec_rate = (
401
- (sum(1 for r in rows if r.get("exec_acc")) / max(total, 1)) if rows else 0.0
402
- )
403
- success_rate = (
404
- (sum(1 for r in rows if r.get("ok")) / max(total, 1)) if rows else 0.0
405
- )
406
- avg_latency = (
407
- round(sum(int(r.get("latency_ms", 0)) for r in rows) / max(total, 1), 1)
408
- if rows
409
- else 0.0
410
- )
411
 
412
- summary = {
413
- "mode": "spider",
414
- "split": split,
415
- "limit": limit,
416
- "config": config_path,
417
- "provider_hint": ("STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL"),
418
- "spider_root": os.getenv("SPIDER_ROOT", ""),
419
- "total": total,
420
- "EM": round(em_rate, 3),
421
- "SM": round(sm_rate, 3),
422
- "ExecAcc": round(exec_rate, 3),
423
- "success_rate": success_rate,
424
- "avg_latency_ms": avg_latency,
425
- "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
426
- }
427
- _save_outputs(rows, summary)
428
 
 
 
429
 
430
- # -------------------- CLI --------------------
 
 
 
 
 
 
431
 
 
432
 
433
- def main() -> None:
434
- ap = argparse.ArgumentParser()
435
- ap.add_argument(
436
- "--spider",
437
- action="store_true",
438
- help="Enable Spider mode (reads from SPIDER_ROOT; ignores --db-path).",
439
- )
440
- ap.add_argument(
441
- "--split",
442
- type=str,
443
- default="dev",
444
- choices=["dev", "train"],
445
- help="Spider split to use (default: dev).",
446
- )
447
- ap.add_argument(
448
- "--limit",
449
- type=int,
450
- default=20,
451
- help="Number of Spider items to evaluate (default: 20).",
452
- )
453
 
454
- ap.add_argument(
455
- "--db-path",
456
- type=str,
457
- default="demo.db",
458
- help="Path to SQLite database file (single-DB mode).",
459
- )
460
- ap.add_argument(
461
- "--dataset-file",
462
- type=str,
463
- default=None,
464
- help="Optional JSON file with questions (single-DB mode).",
465
- )
466
- ap.add_argument(
467
- "--config",
468
- type=str,
469
- default=CONFIG_PATH,
470
- help=f"Pipeline YAML config (default: {CONFIG_PATH})",
471
- )
472
- args = ap.parse_args()
473
 
474
- if args.spider:
475
- if not os.getenv("SPIDER_ROOT"):
476
- raise RuntimeError(
477
- "SPIDER_ROOT is not set. It must point to the folder that directly contains "
478
- "dev.json/train_spider.json and the database/ directory."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  )
480
- _run_spider_mode(args.split, args.limit, args.config)
481
- else:
482
- db_path = Path(args.db_path).resolve()
483
- if not db_path.exists():
484
- raise FileNotFoundError(f"SQLite DB not found: {db_path}")
485
- questions = _load_dataset_from_file(args.dataset_file)
486
- _run_single_db_mode(db_path, questions, args.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
488
 
489
  if __name__ == "__main__":
 
490
  main()
 
1
+ #!/usr/bin/env python3
2
  """
3
+ Enhanced Spider benchmark evaluator for NL2SQL pipeline.
4
+ No external dependencies - uses internal evaluation logic.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
  from __future__ import annotations
8
 
9
  import argparse
 
10
  import json
11
+ import re
12
+ import sqlite3
13
  import time
14
+ from dataclasses import dataclass
15
+ from datetime import datetime
16
  from pathlib import Path
17
+ from typing import Any, Dict, List, Tuple
 
 
 
18
 
19
  from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
20
  from adapters.db.sqlite_adapter import SQLiteAdapter
21
+ from benchmarks.spider_loader import load_spider_sqlite
22
+
23
+ # ==================== Configuration ====================
24
 
25
+ RESULT_ROOT = Path("benchmarks/results_pro")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
27
  RESULT_DIR = RESULT_ROOT / TIMESTAMP
28
 
29
 
30
+ # ==================== SQL Processing ====================
31
 
32
 
33
+ def extract_clean_sql(text: str | None) -> str:
34
+ """Safely extract a clean SQL string from input text possibly containing markdown fences or JSON."""
35
+ # Always initialize variable to empty string
36
+ sql = text or ""
37
 
38
+ # Remove markdown code fences
39
+ sql = re.sub(r"```(?:sql)?\s*\n?", "", sql, flags=re.IGNORECASE)
40
+ sql = re.sub(r"```\s*$", "", sql)
41
 
42
+ # Try JSON pattern like {"sql": "..."}
43
+ m_json = re.search(r'"sql"\s*:\s*"([^"]+)"', sql)
44
+ if m_json:
45
+ sql = m_json.group(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # Clean escaped characters
48
+ sql = sql.replace('\\"', '"').replace("\\n", " ").replace("\\t", " ")
49
 
50
+ # Try to locate SQL statement keywords
51
+ m_sql = re.search(
52
+ r"\b(select|with|insert|update|delete)\b[\s\S]+", sql, re.IGNORECASE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  )
54
+ if m_sql:
55
+ sql = m_sql.group(0)
56
+ sql = re.sub(r"\s+", " ", sql).strip().rstrip(";")
57
+ return sql
58
 
59
 
60
+ def normalize_sql(sql: str) -> str:
61
+ """Enhanced SQL normalization for better matching."""
62
+ if not sql:
63
+ return ""
 
 
 
 
 
 
 
64
 
65
+ sql = sql.strip().upper()
66
+ # Remove all whitespace variations
67
+ sql = re.sub(r"\s+", " ", sql)
68
+ # Remove trailing semicolon
69
+ sql = sql.rstrip(";")
70
 
71
+ # Remove table prefixes (e.g., singer.name -> name)
72
+ sql = re.sub(r"\b\w+\.(\w+)\b", r"\1", sql)
 
73
 
74
+ # Remove AS aliases
75
+ sql = re.sub(r"\s+AS\s+\w+", "", sql, flags=re.IGNORECASE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ # Remove DISTINCT if used with COUNT(*)
78
+ sql = re.sub(r"COUNT\s*\(\s*DISTINCT\s+", "COUNT(", sql)
 
 
 
 
 
 
 
79
 
80
+ # Normalize COUNT variations
81
+ sql = re.sub(r"COUNT\s*\(\s*\w+\s*\)", "COUNT(*)", sql)
82
 
83
+ # Remove LIMIT at end
84
+ sql = re.sub(r"\s+LIMIT\s+\d+$", "", sql)
85
 
86
+ # Normalize quotes
87
+ sql = re.sub(r'"(\w+)"', r"\1", sql)
88
+ sql = re.sub(r"`(\w+)`", r"\1", sql)
89
 
90
+ return sql
 
 
 
 
 
 
91
 
 
 
 
 
 
92
 
93
+ # ==================== Schema Extraction ====================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ def get_database_schema(db_path: Path) -> Dict[str, Any]:
97
+ """Extract complete schema from SQLite database."""
98
+ if not db_path.exists():
99
+ return {}
100
+
101
+ conn = sqlite3.connect(str(db_path))
102
+ cursor = conn.cursor()
103
+
104
+ schema: dict[str, Any] = {"tables": {}}
105
 
106
+ try:
107
+ # Get all tables
108
+ cursor.execute(
109
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
 
 
 
 
110
  )
111
+ tables = cursor.fetchall()
112
+
113
+ for (table_name,) in tables:
114
+ # Get columns
115
+ cursor.execute(f"PRAGMA table_info('{table_name}')")
116
+ columns = cursor.fetchall()
117
+
118
+ col_info = []
119
+ for col in columns:
120
+ col_name = col[1]
121
+ col_type = col[2]
122
+ is_pk = col[5]
123
+
124
+ col_dict = {
125
+ "name": col_name,
126
+ "type": col_type,
127
+ "primary_key": bool(is_pk),
128
+ }
129
+ col_info.append(col_dict)
130
 
131
+ # Get foreign keys
132
+ cursor.execute(f"PRAGMA foreign_key_list('{table_name}')")
133
+ fks = cursor.fetchall()
134
 
135
+ fk_info = []
136
+ for fk in fks:
137
+ fk_info.append(
138
+ {
139
+ "column": fk[3],
140
+ "referenced_table": fk[2],
141
+ "referenced_column": fk[4],
142
+ }
143
+ )
144
 
145
+ schema["tables"][table_name] = {
146
+ "columns": col_info,
147
+ "foreign_keys": fk_info,
148
+ }
149
 
150
+ finally:
151
+ conn.close()
152
 
153
+ return schema
 
154
 
 
 
 
 
 
 
 
 
 
155
 
156
+ def format_schema_for_prompt(schema: Dict[str, Any]) -> str:
157
+ """Format schema for LLM prompt."""
158
+ if not schema or not schema.get("tables"):
159
+ return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ lines = []
162
+ for table_name, table_info in schema["tables"].items():
163
+ cols = []
164
+ for col in table_info["columns"]:
165
+ col_str = f"{col['name']} {col['type']}"
166
+ if col.get("primary_key"):
167
+ col_str += " PRIMARY KEY"
168
+ cols.append(col_str)
 
 
 
 
 
 
 
 
169
 
170
+ lines.append(f"Table: {table_name}")
171
+ lines.append(f"Columns: {', '.join(cols)}")
172
 
173
+ if table_info.get("foreign_keys"):
174
+ fks = []
175
+ for fk in table_info["foreign_keys"]:
176
+ fks.append(
177
+ f"{fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}"
178
+ )
179
+ lines.append(f"Foreign Keys: {', '.join(fks)}")
180
 
181
+ lines.append("") # Empty line between tables
182
 
183
+ return "\n".join(lines).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ # ==================== SQL Evaluation ====================
187
+
188
+
189
+ def execute_sql(db_path: Path, sql: str) -> Tuple[bool, List[Tuple]]:
190
+ """Execute SQL and return success flag and results."""
191
+ if not sql:
192
+ return False, []
193
+
194
+ try:
195
+ conn = sqlite3.connect(str(db_path))
196
+ cursor = conn.cursor()
197
+ cursor.execute(sql)
198
+ results = cursor.fetchall()
199
+ conn.close()
200
+ return True, results
201
+ except Exception:
202
+ return False, []
203
+
204
+
205
+ def compare_sql_results(gold_results: List[Tuple], pred_results: List[Tuple]) -> bool:
206
+ """Compare SQL execution results."""
207
+ if len(gold_results) != len(pred_results):
208
+ return False
209
+
210
+ # Convert to sets for comparison (order independent)
211
+ gold_set = set(gold_results)
212
+ pred_set = set(pred_results)
213
+
214
+ return gold_set == pred_set
215
+
216
+
217
+ def evaluate_sql_match(pred_sql: str, gold_sql: str, db_path: Path) -> Dict[str, float]:
218
+ """Evaluate predicted SQL against gold SQL."""
219
+ metrics = {"exact_match": 0.0, "set_match": 0.0, "exec_accuracy": 0.0}
220
+
221
+ if not pred_sql:
222
+ return metrics
223
+
224
+ # Exact match
225
+ if normalize_sql(pred_sql) == normalize_sql(gold_sql):
226
+ metrics["exact_match"] = 1.0
227
+
228
+ # Execution-based evaluation
229
+ gold_success, gold_results = execute_sql(db_path, gold_sql)
230
+ pred_success, pred_results = execute_sql(db_path, pred_sql)
231
+
232
+ if gold_success and pred_success:
233
+ # Set match (results match)
234
+ if compare_sql_results(gold_results, pred_results):
235
+ metrics["set_match"] = 1.0
236
+ metrics["exec_accuracy"] = 1.0
237
+ else:
238
+ # Partial credit for successful execution
239
+ metrics["exec_accuracy"] = 0.5
240
+
241
+ return metrics
242
+
243
+
244
+ # ==================== Pipeline Runner ====================
245
+
246
+
247
+ @dataclass
248
+ class SpiderSample:
249
+ """Spider dataset sample."""
250
+
251
+ question: str
252
+ db_id: str
253
+ db_path: Path
254
+ gold_sql: str
255
+
256
+
257
+ def run_pipeline_on_sample(
258
+ pipeline: Any,
259
+ sample: SpiderSample,
260
+ schema_cache: Dict[str, str],
261
+ debug: bool = False,
262
+ ) -> Dict[str, Any]:
263
+ """Run NL2SQL pipeline on a single sample."""
264
+
265
+ # Get/cache schema
266
+ if sample.db_id not in schema_cache:
267
+ schema_dict = get_database_schema(sample.db_path)
268
+ schema_str = format_schema_for_prompt(schema_dict)
269
+ schema_cache[sample.db_id] = schema_str
270
+ if debug:
271
+ print(f" [schema] Loaded {len(schema_str)} chars for {sample.db_id}")
272
+
273
+ schema: str = schema_cache[sample.db_id]
274
+
275
+ # Run pipeline
276
+ try:
277
+ result = pipeline.run(user_query=sample.question, schema_preview=schema)
278
+
279
+ # Extract SQL from result
280
+ if hasattr(result, "sql") and result.sql:
281
+ pred_sql = extract_clean_sql(result.sql)
282
+ else:
283
+ # Try to extract from various fields
284
+ for attr in ["final_sql", "generated_sql", "answer"]:
285
+ if hasattr(result, attr):
286
+ val = getattr(result, attr)
287
+ if val:
288
+ pred_sql = extract_clean_sql(str(val))
289
+ if pred_sql:
290
+ break
291
+ else:
292
+ pred_sql = ""
293
+
294
+ return {
295
+ "ok": bool(getattr(result, "ok", True)),
296
+ "sql": pred_sql,
297
+ "raw_response": getattr(result, "sql", ""),
298
+ "traces": getattr(result, "traces", []),
299
+ "error": None,
300
+ }
301
+
302
+ except Exception as e:
303
+ if debug:
304
+ import traceback
305
+
306
+ traceback.print_exc()
307
+ return {
308
+ "ok": False,
309
+ "sql": "",
310
+ "raw_response": "",
311
+ "traces": [],
312
+ "error": str(e),
313
+ }
314
+
315
+
316
+ # ==================== Main Evaluation ====================
317
+
318
+
319
+ def main():
320
+ parser = argparse.ArgumentParser(description="Evaluate NL2SQL on Spider")
321
+ parser.add_argument("--spider", action="store_true", help="Run Spider evaluation")
322
+ parser.add_argument("--split", default="dev", choices=["dev", "train"])
323
+ parser.add_argument("--limit", type=int, help="Limit number of samples")
324
+ parser.add_argument("--debug", action="store_true", help="Enable debug output")
325
+ parser.add_argument("--config", default="configs/sqlite_pipeline.yaml")
326
+
327
+ args = parser.parse_args()
328
+
329
+ if not args.spider:
330
+ print("Please use --spider flag to run Spider evaluation")
331
+ return
332
+
333
+ # Load Spider samples
334
+ print(f"Loading Spider {args.split} split...")
335
+ samples = load_spider_sqlite(split=args.split, limit=args.limit)
336
+
337
+ if not samples:
338
+ print("❌ No samples loaded. Check SPIDER_ROOT environment variable.")
339
+ return
340
+
341
+ print(f"✔ Loaded {len(samples)} samples")
342
+
343
+ # Prepare results directory
344
+ RESULT_DIR.mkdir(parents=True, exist_ok=True)
345
+
346
+ # Initialize schema cache
347
+ schema_cache = {}
348
+
349
+ # Process each sample
350
+ results = []
351
+ for i, spider_item in enumerate(samples, 1):
352
+ # Convert to our sample format
353
+ sample = SpiderSample(
354
+ question=spider_item.question,
355
+ db_id=spider_item.db_id,
356
+ db_path=Path(spider_item.db_path),
357
+ gold_sql=spider_item.gold_sql,
358
+ )
359
+
360
+ print(f"\n🧠 [{i}/{len(samples)}] [{sample.db_id}] {sample.question}")
361
+
362
+ # Create adapter and pipeline for this database
363
+ adapter = SQLiteAdapter(sample.db_path)
364
+ pipeline = pipeline_from_config_with_adapter(args.config, adapter=adapter)
365
+
366
+ # Run pipeline
367
+ t0 = time.perf_counter()
368
+ result = run_pipeline_on_sample(pipeline, sample, schema_cache, args.debug)
369
+ latency_ms = int((time.perf_counter() - t0) * 1000)
370
+
371
+ # Evaluate
372
+ metrics = evaluate_sql_match(result["sql"], sample.gold_sql, sample.db_path)
373
+
374
+ # Store result
375
+ eval_result = {
376
+ "source": "spider",
377
+ "db_id": sample.db_id,
378
+ "query": sample.question,
379
+ "gold_sql": sample.gold_sql,
380
+ "pred_sql": result["sql"],
381
+ "ok": result["ok"],
382
+ "latency_ms": latency_ms,
383
+ "em": metrics["exact_match"],
384
+ "sm": metrics["set_match"],
385
+ "exec_acc": metrics["exec_accuracy"],
386
+ "error": result.get("error"),
387
+ "trace": result.get("traces", []),
388
+ }
389
+ results.append(eval_result)
390
+
391
+ # Debug output
392
+ if args.debug:
393
+ status = "✅" if result["ok"] and metrics["exact_match"] == 1 else "⚠️"
394
+ print(
395
+ f"{status} ({latency_ms} ms) | EM={metrics['exact_match']:.0f} SM={metrics['set_match']:.0f} ExecAcc={metrics['exec_accuracy']:.1f}"
396
  )
397
+ if metrics["exact_match"] < 1:
398
+ print(f" gold: {sample.gold_sql[:100]}")
399
+ print(f" pred: {result['sql'][:100] if result['sql'] else 'EMPTY'}")
400
+
401
+ # Calculate aggregates
402
+ total = len(results)
403
+ successful = sum(1 for r in results if r["ok"])
404
+ avg_em = sum(r["em"] for r in results) / total if total > 0 else 0
405
+ avg_sm = sum(r["sm"] for r in results) / total if total > 0 else 0
406
+ avg_ea = sum(r["exec_acc"] for r in results) / total if total > 0 else 0
407
+ avg_latency = sum(r["latency_ms"] for r in results) / total if total > 0 else 0
408
+
409
+ # Save results
410
+ eval_jsonl = RESULT_DIR / "eval.jsonl"
411
+ with open(eval_jsonl, "w") as f:
412
+ for r in results:
413
+ json.dump(r, f, ensure_ascii=False)
414
+ f.write("\n")
415
+
416
+ summary = {
417
+ "timestamp": datetime.now().isoformat(timespec="seconds"),
418
+ "total": total,
419
+ "success": successful,
420
+ "success_rate": round(successful / total, 3) if total else 0,
421
+ "avg_latency_ms": round(avg_latency, 1),
422
+ "EM": round(avg_em, 3),
423
+ "SM": round(avg_sm, 3),
424
+ "ExecAcc": round(avg_ea, 3),
425
+ "split": args.split,
426
+ "config": args.config,
427
+ }
428
+
429
+ (RESULT_DIR / "summary.json").write_text(
430
+ json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8"
431
+ )
432
+
433
+ print("\n================== Evaluation Summary ==================")
434
+ print(f"Total samples: {total}")
435
+ print(f"Successful runs: {successful} ({summary['success_rate'] * 100:.1f}%)")
436
+ print(f"Avg EM: {summary['EM']}")
437
+ print(f"Avg SM: {summary['SM']}")
438
+ print(f"Avg ExecAcc: {summary['ExecAcc']}")
439
+ print(f"Avg Latency: {summary['avg_latency_ms']} ms")
440
+ print(f"Results saved to {RESULT_DIR}")
441
+ print("========================================================")
442
 
443
 
444
  if __name__ == "__main__":
445
+ RESULT_DIR.mkdir(parents=True, exist_ok=True)
446
  main()
benchmarks/results_pro/20251108-123204/eval.jsonl DELETED
@@ -1,5 +0,0 @@
1
- {"source": "demo", "db_id": "demo", "query": "list all customers", "ok": false, "latency_ms": 8406, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 3768}, {"stage": "generator", "ms": 1616}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 3}, {"stage": "repair", "ms": 1639}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1367}, {"stage": "safety", "ms": 3}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
2
- {"source": "demo", "db_id": "demo", "query": "show total invoices per country", "ok": true, "latency_ms": 11003, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 5021}, {"stage": "generator", "ms": 1605}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1437}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 2929}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
3
- {"source": "demo", "db_id": "demo", "query": "top 3 albums by total sales", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}], "error": null}
4
- {"source": "demo", "db_id": "demo", "query": "artists with more than 3 albums", "ok": false, "latency_ms": 14409, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 8377}, {"stage": "generator", "ms": 2525}, {"stage": "safety", "ms": 4}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1618}, {"stage": "safety", "ms": 4}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1874}, {"stage": "safety", "ms": 3}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
5
- {"source": "demo", "db_id": "demo", "query": "number of employees per city", "ok": true, "latency_ms": 8938, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4402}, {"stage": "generator", "ms": 1846}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1397}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1283}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
 
 
 
 
 
 
benchmarks/results_pro/20251108-123204/latency_per_stage.png DELETED
Binary file (34.7 kB)
 
benchmarks/results_pro/20251108-123204/metrics_overview.png DELETED
Binary file (22.7 kB)
 
benchmarks/results_pro/20251108-123204/results.csv DELETED
@@ -1,6 +0,0 @@
1
- source,db_id,query,em,sm,exec_acc,ok,latency_ms
2
- demo,demo,list all customers,,,,❌,8406
3
- demo,demo,show total invoices per country,,,,✅,11003
4
- demo,demo,top 3 albums by total sales,,,,✅,1
5
- demo,demo,artists with more than 3 albums,,,,❌,14409
6
- demo,demo,number of employees per city,,,,✅,8938
 
 
 
 
 
 
 
benchmarks/results_pro/20251108-123204/summary.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "mode": "single-db",
3
- "db_path": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/demo.db",
4
- "config": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/configs/sqlite_pipeline.yaml",
5
- "provider_hint": "REAL",
6
- "total": 5,
7
- "EM": 0.0,
8
- "SM": 0.0,
9
- "ExecAcc": 0.0,
10
- "success_rate": 0.6,
11
- "avg_latency_ms": 8551.4,
12
- "timestamp": "2025-11-08 12:32:47"
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/results_pro/20251108-124153/eval.jsonl DELETED
@@ -1,5 +0,0 @@
1
- {"source": "demo", "db_id": "demo", "query": "list all customers", "ok": false, "latency_ms": 6756, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 2729}, {"stage": "generator", "ms": 1343}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 2}, {"stage": "repair", "ms": 911}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1763}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
2
- {"source": "demo", "db_id": "demo", "query": "show total invoices per country", "ok": true, "latency_ms": 8901, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4799}, {"stage": "generator", "ms": 1075}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1092}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1924}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
3
- {"source": "demo", "db_id": "demo", "query": "top 3 albums by total sales", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}], "error": null}
4
- {"source": "demo", "db_id": "demo", "query": "artists with more than 3 albums", "ok": false, "latency_ms": 12342, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4882}, {"stage": "generator", "ms": 2684}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 2630}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 2135}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}], "error": null}
5
- {"source": "demo", "db_id": "demo", "query": "number of employees per city", "ok": true, "latency_ms": 7547, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4083}, {"stage": "generator", "ms": 1269}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1149}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1035}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
 
 
 
 
 
 
benchmarks/results_pro/20251108-124153/latency_per_stage.png DELETED
Binary file (34.7 kB)
 
benchmarks/results_pro/20251108-124153/metrics_overview.png DELETED
Binary file (22.7 kB)
 
benchmarks/results_pro/20251108-124153/results.csv DELETED
@@ -1,6 +0,0 @@
1
- source,db_id,query,em,sm,exec_acc,ok,latency_ms
2
- demo,demo,list all customers,,,,❌,6756
3
- demo,demo,show total invoices per country,,,,✅,8901
4
- demo,demo,top 3 albums by total sales,,,,✅,1
5
- demo,demo,artists with more than 3 albums,,,,❌,12342
6
- demo,demo,number of employees per city,,,,✅,7547
 
 
 
 
 
 
 
benchmarks/results_pro/20251108-124153/summary.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "mode": "single-db",
3
- "db_path": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/demo.db",
4
- "config": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/configs/sqlite_pipeline.yaml",
5
- "provider_hint": "REAL",
6
- "total": 5,
7
- "EM": 0.0,
8
- "SM": 0.0,
9
- "ExecAcc": 0.0,
10
- "success_rate": 0.6,
11
- "avg_latency_ms": 7109.4,
12
- "timestamp": "2025-11-08 12:42:29"
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/results_pro/20251108-125829/eval.jsonl DELETED
@@ -1,5 +0,0 @@
1
- {"source": "demo", "db_id": "demo", "query": "list all customers", "ok": false, "latency_ms": 6652, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 2554}, {"stage": "generator", "ms": 1370}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 1}, {"stage": "repair", "ms": 1295}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "repair", "ms": 1426}, {"stage": "safety", "ms": 0}, {"stage": "executor", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
2
- {"source": "demo", "db_id": "demo", "query": "show total invoices per country", "ok": true, "latency_ms": 7375, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 3866}, {"stage": "generator", "ms": 1265}, {"stage": "safety", "ms": 4}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1126}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1106}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
3
- {"source": "demo", "db_id": "demo", "query": "top 3 albums by total sales", "ok": true, "latency_ms": 1, "trace": [{"stage": "detector", "ms": 0}], "error": null}
4
- {"source": "demo", "db_id": "demo", "query": "artists with more than 3 albums", "ok": false, "latency_ms": 8629, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 4110}, {"stage": "generator", "ms": 1969}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1296}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "repair", "ms": 1244}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
5
- {"source": "demo", "db_id": "demo", "query": "number of employees per city", "ok": true, "latency_ms": 5630, "trace": [{"stage": "detector", "ms": 0}, {"stage": "planner", "ms": 2602}, {"stage": "generator", "ms": 1097}, {"stage": "safety", "ms": 1}, {"stage": "executor", "ms": 0}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 1018}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "repair", "ms": 906}, {"stage": "safety", "ms": 2}, {"stage": "executor", "ms": 1}, {"stage": "verifier", "ms": 0}, {"stage": "pipeline", "ms": 0}, {"stage": "pipeline", "ms": 0}], "error": null}
 
 
 
 
 
 
benchmarks/results_pro/20251108-125829/latency_per_stage.png DELETED
Binary file (22.4 kB)
 
benchmarks/results_pro/20251108-125829/metrics_overview.png DELETED
Binary file (12.9 kB)
 
benchmarks/results_pro/20251108-125829/results.csv DELETED
@@ -1,6 +0,0 @@
1
- source,db_id,query,em,sm,exec_acc,ok,latency_ms
2
- demo,demo,list all customers,,,,❌,6652
3
- demo,demo,show total invoices per country,,,,✅,7375
4
- demo,demo,top 3 albums by total sales,,,,✅,1
5
- demo,demo,artists with more than 3 albums,,,,❌,8629
6
- demo,demo,number of employees per city,,,,✅,5630
 
 
 
 
 
 
 
benchmarks/results_pro/20251108-125829/summary.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "mode": "single-db",
3
- "db_path": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/demo.db",
4
- "config": "/Users/melikakheirieh/Desktop/my/career-developement/LLM/nl2sql-copilot/configs/sqlite_pipeline.yaml",
5
- "provider_hint": "REAL",
6
- "total": 5,
7
- "EM": 0.0,
8
- "SM": 0.0,
9
- "ExecAcc": 0.0,
10
- "success_rate": 0.6,
11
- "avg_latency_ms": 5657.4,
12
- "timestamp": "2025-11-08 12:58:58"
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/results_pro/20251109-092540/eval.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 9423, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6884, "summary": "ok", "notes": {"len_plan": 1313}, "token_in": 270, "token_out": 313, "cost_usd": 0.0002283}, {"stage": "generator", "duration_ms": 891, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 801, "token_out": 19, "cost_usd": 0.00013155}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 673, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 962, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": 
null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
2
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 9382, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6936, "summary": "ok", "notes": {"len_plan": 1501}, "token_in": 271, "token_out": 351, "cost_usd": 0.00025124999999999995}, {"stage": "generator", "duration_ms": 1014, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 840, "token_out": 19, "cost_usd": 0.00013739999999999998}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 2, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 710, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 710, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, 
"summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
3
+ {"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}}]}
4
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 11380, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 7152, "summary": "ok", "notes": {"len_plan": 1281}, "token_in": 281, "token_out": 295, "cost_usd": 0.00021914999999999996}, {"stage": "generator", "duration_ms": 2189, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 794, "token_out": 37, "cost_usd": 0.0001413}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 954, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64}, "token_in": 325, "token_out": 21, "cost_usd": 6.135e-05}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1074, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64}, "token_in": 328, 
"token_out": 21, "cost_usd": 6.18e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
5
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 10894, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 7383, "summary": "ok", "notes": {"len_plan": 1579}, "token_in": 279, "token_out": 421, "cost_usd": 0.00029445}, {"stage": "generator", "duration_ms": 1242, "summary": "ok", "notes": {"rationale_len": 67}, "token_in": 918, "token_out": 42, "cost_usd": 0.00016289999999999998}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1078, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 333, "token_out": 24, "cost_usd": 6.435e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 3, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1173, 
"summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 337, "token_out": 28, "cost_usd": 6.735e-05}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
benchmarks/results_pro/20251109-092540/summary.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": "2025-11-09T09:26:21",
3
+ "total": 5,
4
+ "success": 5,
5
+ "success_rate": 1.0,
6
+ "avg_latency_ms": 8215.8,
7
+ "EM": 0.4,
8
+ "SM": 0.8,
9
+ "ExecAcc": 0.8,
10
+ "split": "dev",
11
+ "config": "configs/sqlite_pipeline.yaml"
12
+ }
benchmarks/results_pro/20251109-092823/eval.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 7982, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 5384, "summary": "ok", "notes": {"len_plan": 1287}, "token_in": 270, "token_out": 306, "cost_usd": 0.0002241}, {"stage": "generator", "duration_ms": 900, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 794, "token_out": 19, "cost_usd": 0.0001305}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 888, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 797, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": 
null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
2
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 9717, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6881, "summary": "ok", "notes": {"len_plan": 1352}, "token_in": 271, "token_out": 319, "cost_usd": 0.00023204999999999998}, {"stage": "generator", "duration_ms": 1162, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 808, "token_out": 19, "cost_usd": 0.0001326}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 716, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 950, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", 
"notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
3
+ {"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}}]}
4
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 8523, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 5311, "summary": "ok", "notes": {"len_plan": 1449}, "token_in": 281, "token_out": 343, "cost_usd": 0.00024795}, {"stage": "generator", "duration_ms": 1306, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 842, "token_out": 37, "cost_usd": 0.00014849999999999998}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 996, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64}, "token_in": 325, "token_out": 21, "cost_usd": 6.135e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 900, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64}, "token_in": 328, 
"token_out": 21, "cost_usd": 6.18e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
5
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 12291, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 8346, "summary": "ok", "notes": {"len_plan": 1363}, "token_in": 279, "token_out": 334, "cost_usd": 0.00024225}, {"stage": "generator", "duration_ms": 1636, "summary": "ok", "notes": {"rationale_len": 87}, "token_in": 831, "token_out": 46, "cost_usd": 0.00015225}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 2, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1137, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 333, "token_out": 25, "cost_usd": 6.495e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 3, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1151, "summary": "ok", 
"notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 337, "token_out": 28, "cost_usd": 6.735e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
benchmarks/results_pro/20251109-092823/summary.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": "2025-11-09T09:29:01",
3
+ "total": 5,
4
+ "success": 5,
5
+ "success_rate": 1.0,
6
+ "avg_latency_ms": 7702.6,
7
+ "EM": 0.4,
8
+ "SM": 0.8,
9
+ "ExecAcc": 0.8,
10
+ "split": "dev",
11
+ "config": "configs/sqlite_pipeline.yaml"
12
+ }
benchmarks/results_pro/20251109-093743/eval.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 10480, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 8010, "summary": "ok", "notes": {"len_plan": 1445}, "token_in": 270, "token_out": 337, "cost_usd": 0.00024270000000000002}, {"stage": "generator", "duration_ms": 1029, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 825, "token_out": 19, "cost_usd": 0.00013514999999999998}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 678, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 2, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 750, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 1, "summary": 
"ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
2
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer limit 1", "ok": true, "latency_ms": 10687, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 6978, "summary": "ok", "notes": {"len_plan": 1512}, "token_in": 271, "token_out": 355, "cost_usd": 0.00025364999999999996}, {"stage": "generator", "duration_ms": 2192, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 844, "token_out": 19, "cost_usd": 0.000138}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 652, "summary": "ok", "notes": {"old_sql_len": 27, "new_sql_len": 35}, "token_in": 318, "token_out": 8, "cost_usd": 5.2499999999999995e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 863, "summary": "ok", "notes": {"old_sql_len": 35, "new_sql_len": 35}, "token_in": 321, "token_out": 8, "cost_usd": 5.295e-05}, {"stage": "safety", "duration_ms": 0, "summary": "ok", 
"notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 1, "col_count": 1}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
3
+ {"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}}]}
4
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc LIMIT 10", "ok": true, "latency_ms": 16736, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 13205, "summary": "ok", "notes": {"len_plan": 1758}, "token_in": 281, "token_out": 409, "cost_usd": 0.00028754999999999997}, {"stage": "generator", "duration_ms": 1537, "summary": "ok", "notes": {"rationale_len": 83}, "token_in": 908, "token_out": 37, "cost_usd": 0.0001584}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1019, "summary": "ok", "notes": {"old_sql_len": 55, "new_sql_len": 64}, "token_in": 325, "token_out": 21, "cost_usd": 6.135e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 0, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 968, "summary": "ok", "notes": {"old_sql_len": 64, "new_sql_len": 64}, "token_in": 328, 
"token_out": 21, "cost_usd": 6.18e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"issues": ["exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
5
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": true, "latency_ms": 12440, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}}, {"stage": "planner", "duration_ms": 7973, "summary": "ok", "notes": {"len_plan": 1377}, "token_in": 279, "token_out": 345, "cost_usd": 0.00024884999999999995}, {"stage": "generator", "duration_ms": 1827, "summary": "ok", "notes": {"rationale_len": 94}, "token_in": 841, "token_out": 47, "cost_usd": 0.00015434999999999998}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", "duration_ms": 1312, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 333, "token_out": 24, "cost_usd": 6.435e-05}, {"stage": "safety", "duration_ms": 3, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 2, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "repair", 
"duration_ms": 1313, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 337, "token_out": 21, "cost_usd": 6.315e-05}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "verifier", "duration_ms": 1, "summary": "ok", "notes": {"issues": ["aggregation_without_group_by", "exec_error:preview_failed"]}, "token_in": null, "token_out": null, "cost_usd": null}, {"stage": "pipeline", "duration_ms": 0, "summary": "auto-verified", "notes": {"reason": "executor succeeded, verifier silent"}}, {"stage": "pipeline", "duration_ms": 0, "summary": "finalize", "notes": {"final_verified": true, "details_len": 0, "need_verification": false}}]}
benchmarks/results_pro/20251109-093743/summary.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": "2025-11-09T09:38:33",
3
+ "total": 5,
4
+ "success": 5,
5
+ "success_rate": 1.0,
6
+ "avg_latency_ms": 10068.6,
7
+ "EM": 0.4,
8
+ "SM": 0.8,
9
+ "ExecAcc": 0.8,
10
+ "split": "dev",
11
+ "config": "configs/sqlite_pipeline.yaml"
12
+ }
nl2sql/pipeline.py CHANGED
@@ -31,9 +31,8 @@ class FinalResult:
31
 
32
  class Pipeline:
33
  """
34
- NL2SQL Copilot pipeline.
35
- Stages return StageResult; final result is a type-safe FinalResult.
36
- DI-ready: all dependencies are injected via __init__.
37
  """
38
 
39
  def __init__(
@@ -54,22 +53,21 @@ class Pipeline:
54
  self.executor = executor or NoOpExecutor()
55
  self.verifier = verifier or NoOpVerifier()
56
  self.repair = repair or NoOpRepair()
 
 
57
 
58
- # ------------------------------------------------------------
59
  @staticmethod
60
  def _trace_list(*stages: Optional[StageResult]) -> List[dict]:
61
- """Collect .trace objects (as dict) from StageResult items if present."""
62
  traces: List[dict] = []
63
  for s in stages:
64
  if not s:
65
  continue
66
  t = getattr(s, "trace", None)
67
  if t is not None:
68
- # t is likely a dataclass – expose as plain dict for JSON safety
69
  traces.append(getattr(t, "__dict__", t))
70
  return traces
71
 
72
- # ------------------------------------------------------------
73
  @staticmethod
74
  def _mk_trace(
75
  stage: str,
@@ -77,7 +75,6 @@ class Pipeline:
77
  summary: str,
78
  notes: Optional[Dict[str, Any]] = None,
79
  ) -> dict:
80
- """Create a normalized trace dict (internal: duration may be float)."""
81
  return {
82
  "stage": stage,
83
  "duration_ms": float(duration_ms),
@@ -87,11 +84,6 @@ class Pipeline:
87
 
88
  @staticmethod
89
  def _normalize_traces(traces: List[dict]) -> List[dict]:
90
- """
91
- Normalize trace list for API/UI:
92
- - coerce duration_ms to int
93
- - ensure `summary` exists (fallback to a minimal one)
94
- """
95
  norm: List[dict] = []
96
  for t in traces:
97
  stage = str(t.get("stage", "unknown"))
@@ -100,37 +92,24 @@ class Pipeline:
100
  dur_int = int(round(float(dur)))
101
  except Exception:
102
  dur_int = 0
103
- summary = t.get("summary")
104
- if not summary:
105
- # fallback summary if not provided by stage
106
- notes = t.get("notes") or {}
107
- failed = bool(notes.get("error") or notes.get("errors"))
108
- summary = "failed" if failed else "ok"
109
  notes = t.get("notes") or {}
110
- # preserve any accounting fields if present (token_in/out, cost_usd, ...)
 
 
111
  payload = {
112
  "stage": stage,
113
  "duration_ms": dur_int,
114
  "summary": summary,
115
  "notes": notes,
116
  }
117
- # keep extra accounting if exists
118
- if "token_in" in t:
119
- payload["token_in"] = t["token_in"]
120
- if "token_out" in t:
121
- payload["token_out"] = t["token_out"]
122
- if "cost_usd" in t:
123
- payload["cost_usd"] = t["cost_usd"]
124
  norm.append(payload)
125
  return norm
126
 
127
- # ------------------------------------------------------------
128
  @staticmethod
129
  def _safe_stage(fn, **kwargs) -> StageResult:
130
- """
131
- Run a stage safely; if it throws, return a StageResult(ok=False, error=[...]).
132
- If fn returns a non-StageResult (e.g., dict), coerce to StageResult(ok=True, data=...).
133
- """
134
  try:
135
  r = fn(**kwargs)
136
  if isinstance(r, StageResult):
@@ -140,7 +119,7 @@ class Pipeline:
140
  tb = traceback.format_exc()
141
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
142
 
143
- # ------------------------------------------------------------
144
  def run(
145
  self,
146
  *,
@@ -152,7 +131,6 @@ class Pipeline:
152
  traces: List[dict] = []
153
  details: List[str] = []
154
 
155
- # Always push a normalized per-stage timing, even if StageResult.trace is empty
156
  def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
157
  traces.append(
158
  self._mk_trace(
@@ -162,26 +140,24 @@ class Pipeline:
162
  )
163
  )
164
 
165
- # Normalize inputs
166
  schema_preview = schema_preview or ""
167
  clarify_answers = clarify_answers or {}
168
 
169
  try:
170
  # --- 1) detector ---
171
- t_det0 = time.perf_counter()
172
  questions = self.detector.detect(user_query, schema_preview)
173
- det_ms = (time.perf_counter() - t_det0) * 1000.0
174
  is_amb = bool(questions)
175
- stage_duration_ms.labels("detector").observe(det_ms)
176
  traces.append(
177
  self._mk_trace(
178
  stage="detector",
179
- duration_ms=det_ms,
180
  summary=("ambiguous" if is_amb else "clear"),
181
  notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
182
  )
183
  )
184
-
185
  if questions:
186
  pipeline_runs_total.labels(status="ambiguous").inc()
187
  return FinalResult(
@@ -197,15 +173,15 @@ class Pipeline:
197
  )
198
 
199
  # --- 2) planner ---
200
- t_pln0 = time.perf_counter()
201
  r_plan = self._safe_stage(
202
  self.planner.run, user_query=user_query, schema_preview=schema_preview
203
  )
204
- pln_ms = (time.perf_counter() - t_pln0) * 1000.0
205
- stage_duration_ms.labels("planner").observe(pln_ms)
206
  traces.extend(self._trace_list(r_plan))
207
  if not getattr(r_plan, "trace", None):
208
- _fallback_trace("planner", pln_ms, r_plan.ok)
209
  if not r_plan.ok:
210
  pipeline_runs_total.labels(status="error").inc()
211
  return FinalResult(
@@ -221,7 +197,7 @@ class Pipeline:
221
  )
222
 
223
  # --- 3) generator ---
224
- t_gen0 = time.perf_counter()
225
  r_gen = self._safe_stage(
226
  self.generator.run,
227
  user_query=user_query,
@@ -229,11 +205,11 @@ class Pipeline:
229
  plan_text=(r_plan.data or {}).get("plan"),
230
  clarify_answers=clarify_answers,
231
  )
232
- gen_ms = (time.perf_counter() - t_gen0) * 1000.0
233
- stage_duration_ms.labels("generator").observe(gen_ms)
234
  traces.extend(self._trace_list(r_gen))
235
  if not getattr(r_gen, "trace", None):
236
- _fallback_trace("generator", gen_ms, r_gen.ok)
237
  if not r_gen.ok:
238
  pipeline_runs_total.labels(status="error").inc()
239
  return FinalResult(
@@ -251,14 +227,32 @@ class Pipeline:
251
  sql = (r_gen.data or {}).get("sql")
252
  rationale = (r_gen.data or {}).get("rationale")
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  # --- 4) safety ---
255
- t_saf0 = time.perf_counter()
256
  r_safe = self._safe_stage(self.safety.run, sql=sql)
257
- saf_ms = (time.perf_counter() - t_saf0) * 1000.0
258
- stage_duration_ms.labels("safety").observe(saf_ms)
259
  traces.extend(self._trace_list(r_safe))
260
  if not getattr(r_safe, "trace", None):
261
- _fallback_trace("safety", saf_ms, r_safe.ok)
262
  if not r_safe.ok:
263
  pipeline_runs_total.labels(status="error").inc()
264
  return FinalResult(
@@ -273,99 +267,112 @@ class Pipeline:
273
  traces=self._normalize_traces(traces),
274
  )
275
 
 
 
 
276
  # --- 5) executor ---
277
- t_exe0 = time.perf_counter()
278
- r_exec = self._safe_stage(
279
- self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
280
- )
281
- exe_ms = (time.perf_counter() - t_exe0) * 1000.0
282
- stage_duration_ms.labels("executor").observe(exe_ms)
283
  traces.extend(self._trace_list(r_exec))
284
  if not getattr(r_exec, "trace", None):
285
- _fallback_trace("executor", exe_ms, r_exec.ok)
286
  if not r_exec.ok and r_exec.error:
287
- # executor failure is soft; collect for repair/verifier context
288
- details.extend(r_exec.error)
289
 
290
  # --- 6) verifier ---
291
- t_ver0 = time.perf_counter()
292
  r_ver = self._safe_stage(
293
- self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
 
 
 
 
 
294
  )
295
- ver_ms = (time.perf_counter() - t_ver0) * 1000.0
296
- stage_duration_ms.labels("verifier").observe(ver_ms)
297
  traces.extend(self._trace_list(r_ver))
298
  if not getattr(r_ver, "trace", None):
299
- _fallback_trace("verifier", ver_ms, r_ver.ok)
300
  verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
301
 
302
- # --- 7) repair loop if verification failed ---
 
 
 
 
303
  if not verified:
304
  for _attempt in range(2):
305
  # repair
306
- t_fix0 = time.perf_counter()
307
  r_fix = self._safe_stage(
308
  self.repair.run,
309
  sql=sql,
310
  error_msg="; ".join(details or ["unknown"]),
311
  schema_preview=schema_preview,
312
  )
313
- fix_ms = (time.perf_counter() - t_fix0) * 1000.0
314
- stage_duration_ms.labels("repair").observe(fix_ms)
315
  traces.extend(self._trace_list(r_fix))
316
  if not getattr(r_fix, "trace", None):
317
- _fallback_trace("repair", fix_ms, r_fix.ok)
318
  if not r_fix.ok:
319
- break # give up on repair
320
 
321
- # fixed SQL
322
  sql = (r_fix.data or {}).get("sql", sql)
323
 
324
- # safety
325
- t_saf0 = time.perf_counter()
326
- r_safe = self._safe_stage(self.safety.run, sql=sql)
327
- saf_ms2 = (time.perf_counter() - t_saf0) * 1000.0
328
- stage_duration_ms.labels("safety").observe(saf_ms2)
329
- traces.extend(self._trace_list(r_safe))
330
- if not getattr(r_safe, "trace", None):
331
- _fallback_trace("safety", saf_ms2, r_safe.ok)
332
- if not r_safe.ok:
333
- if r_safe.error:
334
- details.extend(r_safe.error)
335
  continue
336
-
337
- # executor
338
- t_exe0 = time.perf_counter()
339
- r_exec = self._safe_stage(
340
- self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
341
- )
342
- exe_ms2 = (time.perf_counter() - t_exe0) * 1000.0
343
- stage_duration_ms.labels("executor").observe(exe_ms2)
344
- traces.extend(self._trace_list(r_exec))
345
- if not getattr(r_exec, "trace", None):
346
- _fallback_trace("executor", exe_ms2, r_exec.ok)
347
- if not r_exec.ok:
348
- if r_exec.error:
349
- details.extend(r_exec.error)
350
  continue
351
 
352
- # verifier
353
- t_ver0 = time.perf_counter()
354
- r_ver = self._safe_stage(
355
- self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
 
 
 
356
  )
357
- ver_ms2 = (time.perf_counter() - t_ver0) * 1000.0
358
- stage_duration_ms.labels("verifier").observe(ver_ms2)
359
- traces.extend(self._trace_list(r_ver))
360
- if not getattr(r_ver, "trace", None):
361
- _fallback_trace("verifier", ver_ms2, r_ver.ok)
362
  verified = (
363
- bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
364
  )
 
 
365
  if verified:
366
  break
367
 
368
- # --- 8) fallback: verifier silent but executor succeeded ---
369
  if (verified is None or not verified) and not details:
370
  any_exec_ok = any(
371
  t.get("stage") == "executor"
@@ -385,13 +392,24 @@ class Pipeline:
385
 
386
  # --- 9) finalize ---
387
  has_errors = bool(details)
388
- ok = bool(verified) and not has_errors
389
- err = has_errors and not bool(verified)
 
 
 
 
 
 
 
390
 
391
- if ok:
392
- pipeline_runs_total.labels(status="ok").inc()
 
 
393
  else:
394
- pipeline_runs_total.labels(status="error").inc()
 
 
395
 
396
  traces.append(
397
  self._mk_trace(
@@ -399,8 +417,9 @@ class Pipeline:
399
  duration_ms=0.0,
400
  summary="finalize",
401
  notes={
402
- "final_verified": bool(verified),
403
  "details_len": len(details),
 
404
  },
405
  )
406
  )
@@ -412,18 +431,18 @@ class Pipeline:
412
  details=details or None,
413
  sql=sql,
414
  rationale=rationale,
415
- verified=verified,
416
  questions=None,
417
  traces=self._normalize_traces(traces),
418
  )
419
 
420
  except Exception:
421
- # Any unexpected crash
422
  pipeline_runs_total.labels(status="error").inc()
 
423
  raise
424
 
425
  finally:
426
- # Always record total latency even on early-return / exceptions
427
  stage_duration_ms.labels("pipeline_total").observe(
428
  (time.perf_counter() - t_all0) * 1000.0
429
  )
 
31
 
32
  class Pipeline:
33
  """
34
+ NL2SQL Copilot pipeline:
35
+ detector planner generator safety executor → verifier → (optional repair loop).
 
36
  """
37
 
38
  def __init__(
 
53
  self.executor = executor or NoOpExecutor()
54
  self.verifier = verifier or NoOpVerifier()
55
  self.repair = repair or NoOpRepair()
56
+ # If the verifier explicitly requires verification, enforce it in finalize.
57
+ self.require_verification = bool(getattr(self.verifier, "required", False))
58
 
59
+ # ---------------------------- helpers ----------------------------
60
  @staticmethod
61
  def _trace_list(*stages: Optional[StageResult]) -> List[dict]:
 
62
  traces: List[dict] = []
63
  for s in stages:
64
  if not s:
65
  continue
66
  t = getattr(s, "trace", None)
67
  if t is not None:
 
68
  traces.append(getattr(t, "__dict__", t))
69
  return traces
70
 
 
71
  @staticmethod
72
  def _mk_trace(
73
  stage: str,
 
75
  summary: str,
76
  notes: Optional[Dict[str, Any]] = None,
77
  ) -> dict:
 
78
  return {
79
  "stage": stage,
80
  "duration_ms": float(duration_ms),
 
84
 
85
  @staticmethod
86
  def _normalize_traces(traces: List[dict]) -> List[dict]:
 
 
 
 
 
87
  norm: List[dict] = []
88
  for t in traces:
89
  stage = str(t.get("stage", "unknown"))
 
92
  dur_int = int(round(float(dur)))
93
  except Exception:
94
  dur_int = 0
 
 
 
 
 
 
95
  notes = t.get("notes") or {}
96
+ summary = t.get("summary") or (
97
+ "failed" if (notes.get("error") or notes.get("errors")) else "ok"
98
+ )
99
  payload = {
100
  "stage": stage,
101
  "duration_ms": dur_int,
102
  "summary": summary,
103
  "notes": notes,
104
  }
105
+ for k in ("token_in", "token_out", "cost_usd"):
106
+ if k in t:
107
+ payload[k] = t[k]
 
 
 
 
108
  norm.append(payload)
109
  return norm
110
 
 
111
  @staticmethod
112
  def _safe_stage(fn, **kwargs) -> StageResult:
 
 
 
 
113
  try:
114
  r = fn(**kwargs)
115
  if isinstance(r, StageResult):
 
119
  tb = traceback.format_exc()
120
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
121
 
122
+ # ------------------------------ run ------------------------------
123
  def run(
124
  self,
125
  *,
 
131
  traces: List[dict] = []
132
  details: List[str] = []
133
 
 
134
  def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
135
  traces.append(
136
  self._mk_trace(
 
140
  )
141
  )
142
 
 
143
  schema_preview = schema_preview or ""
144
  clarify_answers = clarify_answers or {}
145
 
146
  try:
147
  # --- 1) detector ---
148
+ t0 = time.perf_counter()
149
  questions = self.detector.detect(user_query, schema_preview)
150
+ dt = (time.perf_counter() - t0) * 1000.0
151
  is_amb = bool(questions)
152
+ stage_duration_ms.labels("detector").observe(dt)
153
  traces.append(
154
  self._mk_trace(
155
  stage="detector",
156
+ duration_ms=dt,
157
  summary=("ambiguous" if is_amb else "clear"),
158
  notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
159
  )
160
  )
 
161
  if questions:
162
  pipeline_runs_total.labels(status="ambiguous").inc()
163
  return FinalResult(
 
173
  )
174
 
175
  # --- 2) planner ---
176
+ t0 = time.perf_counter()
177
  r_plan = self._safe_stage(
178
  self.planner.run, user_query=user_query, schema_preview=schema_preview
179
  )
180
+ dt = (time.perf_counter() - t0) * 1000.0
181
+ stage_duration_ms.labels("planner").observe(dt)
182
  traces.extend(self._trace_list(r_plan))
183
  if not getattr(r_plan, "trace", None):
184
+ _fallback_trace("planner", dt, r_plan.ok)
185
  if not r_plan.ok:
186
  pipeline_runs_total.labels(status="error").inc()
187
  return FinalResult(
 
197
  )
198
 
199
  # --- 3) generator ---
200
+ t0 = time.perf_counter()
201
  r_gen = self._safe_stage(
202
  self.generator.run,
203
  user_query=user_query,
 
205
  plan_text=(r_plan.data or {}).get("plan"),
206
  clarify_answers=clarify_answers,
207
  )
208
+ dt = (time.perf_counter() - t0) * 1000.0
209
+ stage_duration_ms.labels("generator").observe(dt)
210
  traces.extend(self._trace_list(r_gen))
211
  if not getattr(r_gen, "trace", None):
212
+ _fallback_trace("generator", dt, r_gen.ok)
213
  if not r_gen.ok:
214
  pipeline_runs_total.labels(status="error").inc()
215
  return FinalResult(
 
227
  sql = (r_gen.data or {}).get("sql")
228
  rationale = (r_gen.data or {}).get("rationale")
229
 
230
+ # Guard: empty SQL
231
+ if not sql or not str(sql).strip():
232
+ pipeline_runs_total.labels(status="error").inc()
233
+ traces.append(
234
+ self._mk_trace("generator", 0.0, "failed", {"reason": "empty_sql"})
235
+ )
236
+ return FinalResult(
237
+ ok=False,
238
+ ambiguous=False,
239
+ error=True,
240
+ details=["empty_sql"],
241
+ questions=None,
242
+ sql=None,
243
+ rationale=rationale,
244
+ verified=None,
245
+ traces=self._normalize_traces(traces),
246
+ )
247
+
248
  # --- 4) safety ---
249
+ t0 = time.perf_counter()
250
  r_safe = self._safe_stage(self.safety.run, sql=sql)
251
+ dt = (time.perf_counter() - t0) * 1000.0
252
+ stage_duration_ms.labels("safety").observe(dt)
253
  traces.extend(self._trace_list(r_safe))
254
  if not getattr(r_safe, "trace", None):
255
+ _fallback_trace("safety", dt, r_safe.ok)
256
  if not r_safe.ok:
257
  pipeline_runs_total.labels(status="error").inc()
258
  return FinalResult(
 
267
  traces=self._normalize_traces(traces),
268
  )
269
 
270
+ # Use sanitized SQL from safety
271
+ sql = (r_safe.data or {}).get("sql", sql)
272
+
273
  # --- 5) executor ---
274
+ t0 = time.perf_counter()
275
+ r_exec = self._safe_stage(self.executor.run, sql=sql)
276
+ dt = (time.perf_counter() - t0) * 1000.0
277
+ stage_duration_ms.labels("executor").observe(dt)
 
 
278
  traces.extend(self._trace_list(r_exec))
279
  if not getattr(r_exec, "trace", None):
280
+ _fallback_trace("executor", dt, r_exec.ok)
281
  if not r_exec.ok and r_exec.error:
282
+ details.extend(r_exec.error) # soft: keep for repair/verifier context
 
283
 
284
  # --- 6) verifier ---
285
+ t0 = time.perf_counter()
286
  r_ver = self._safe_stage(
287
+ self.verifier.run,
288
+ sql=sql,
289
+ exec_result=(r_exec.data or {}),
290
+ adapter=getattr(
291
+ self.executor, "adapter", None
292
+ ), # let verifier use adapter
293
  )
294
+ dt = (time.perf_counter() - t0) * 1000.0
295
+ stage_duration_ms.labels("verifier").observe(dt)
296
  traces.extend(self._trace_list(r_ver))
297
  if not getattr(r_ver, "trace", None):
298
+ _fallback_trace("verifier", dt, r_ver.ok)
299
  verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
300
 
301
+ # consume repaired SQL from verifier if any
302
+ if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
303
+ sql = r_ver.data["sql"]
304
+
305
+ # --- 7) repair loop (if not verified) ---
306
  if not verified:
307
  for _attempt in range(2):
308
  # repair
309
+ t0 = time.perf_counter()
310
  r_fix = self._safe_stage(
311
  self.repair.run,
312
  sql=sql,
313
  error_msg="; ".join(details or ["unknown"]),
314
  schema_preview=schema_preview,
315
  )
316
+ dt = (time.perf_counter() - t0) * 1000.0
317
+ stage_duration_ms.labels("repair").observe(dt)
318
  traces.extend(self._trace_list(r_fix))
319
  if not getattr(r_fix, "trace", None):
320
+ _fallback_trace("repair", dt, r_fix.ok)
321
  if not r_fix.ok:
322
+ break
323
 
324
+ # update SQL
325
  sql = (r_fix.data or {}).get("sql", sql)
326
 
327
+ # safety again
328
+ t0 = time.perf_counter()
329
+ r_safe2 = self._safe_stage(self.safety.run, sql=sql)
330
+ dt2 = (time.perf_counter() - t0) * 1000.0
331
+ stage_duration_ms.labels("safety").observe(dt2)
332
+ traces.extend(self._trace_list(r_safe2))
333
+ if not getattr(r_safe2, "trace", None):
334
+ _fallback_trace("safety", dt2, r_safe2.ok)
335
+ if not r_safe2.ok:
336
+ if r_safe2.error:
337
+ details.extend(r_safe2.error)
338
  continue
339
+ sql = (r_safe2.data or {}).get("sql", sql)
340
+
341
+ # executor again
342
+ t0 = time.perf_counter()
343
+ r_exec2 = self._safe_stage(self.executor.run, sql=sql)
344
+ dt2 = (time.perf_counter() - t0) * 1000.0
345
+ stage_duration_ms.labels("executor").observe(dt2)
346
+ traces.extend(self._trace_list(r_exec2))
347
+ if not getattr(r_exec2, "trace", None):
348
+ _fallback_trace("executor", dt2, r_exec2.ok)
349
+ if not r_exec2.ok:
350
+ if r_exec2.error:
351
+ details.extend(r_exec2.error)
 
352
  continue
353
 
354
+ # verifier again
355
+ t0 = time.perf_counter()
356
+ r_ver2 = self._safe_stage(
357
+ self.verifier.run,
358
+ sql=sql,
359
+ exec_result=(r_exec2.data or {}),
360
+ adapter=getattr(self.executor, "adapter", None),
361
  )
362
+ dt2 = (time.perf_counter() - t0) * 1000.0
363
+ stage_duration_ms.labels("verifier").observe(dt2)
364
+ traces.extend(self._trace_list(r_ver2))
365
+ if not getattr(r_ver2, "trace", None):
366
+ _fallback_trace("verifier", dt2, r_ver2.ok)
367
  verified = (
368
+ bool(r_ver2.data and r_ver2.data.get("verified")) or r_ver2.ok
369
  )
370
+ if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
371
+ sql = r_ver2.data["sql"]
372
  if verified:
373
  break
374
 
375
+ # --- 8) optional soft auto-verify (executor success, no details) ---
376
  if (verified is None or not verified) and not details:
377
  any_exec_ok = any(
378
  t.get("stage") == "executor"
 
392
 
393
  # --- 9) finalize ---
394
  has_errors = bool(details)
395
+ need_ver = bool(self.require_verification)
396
+
397
+ # base success condition
398
+ final_ok_by_verifier = bool(verified)
399
+ base_ok = (
400
+ bool(sql) and not has_errors and (final_ok_by_verifier or not need_ver)
401
+ )
402
+ ok = base_ok
403
+ err = (not ok) and has_errors
404
 
405
+ # align `verified` with baseline semantics:
406
+ # if verification is NOT required and pipeline is ok, report verified=True
407
+ if not need_ver and ok and not final_ok_by_verifier:
408
+ verified_final = True
409
  else:
410
+ verified_final = bool(verified)
411
+
412
+ pipeline_runs_total.labels(status=("ok" if ok else "error")).inc()
413
 
414
  traces.append(
415
  self._mk_trace(
 
417
  duration_ms=0.0,
418
  summary="finalize",
419
  notes={
420
+ "final_verified": bool(verified_final),
421
  "details_len": len(details),
422
+ "need_verification": need_ver,
423
  },
424
  )
425
  )
 
431
  details=details or None,
432
  sql=sql,
433
  rationale=rationale,
434
+ verified=verified_final,
435
  questions=None,
436
  traces=self._normalize_traces(traces),
437
  )
438
 
439
  except Exception:
 
440
  pipeline_runs_total.labels(status="error").inc()
441
+ # bubble up to make failures visible in tests and logs
442
  raise
443
 
444
  finally:
445
+ # Always record total latency, even on early return/exception
446
  stage_duration_ms.labels("pipeline_total").observe(
447
  (time.perf_counter() - t_all0) * 1000.0
448
  )
nl2sql/verifier.py CHANGED
@@ -1,8 +1,7 @@
1
  from __future__ import annotations
2
-
3
  import re
4
  import time
5
- from typing import Any, Iterable, List, Optional
6
 
7
  import sqlglot
8
  from sqlglot import expressions as exp
@@ -10,24 +9,65 @@ from sqlglot import expressions as exp
10
  from nl2sql.types import StageResult, StageTrace
11
  from nl2sql.metrics import (
12
  verifier_checks_total,
13
- stage_duration_ms,
14
  verifier_failures_total,
15
  )
16
 
17
 
18
  def _ms(t0: float) -> int:
 
19
  return int((time.perf_counter() - t0) * 1000)
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  class Verifier:
23
  name = "verifier"
24
 
25
- # Textual fallback: scan for common aggregate calls
26
  _AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
27
 
28
- # ----------------------- AST helpers (version-friendly) --------------------
 
 
 
 
29
  def _walk(self, node: exp.Expression) -> Iterable[exp.Expression]:
30
- """Non-recursive DFS over sqlglot Expression tree (avoid private APIs)."""
31
  stack = [node]
32
  while stack:
33
  cur = stack.pop()
@@ -43,6 +83,7 @@ class Verifier:
43
  stack.append(it)
44
 
45
  def _first_select(self, tree: exp.Expression) -> Optional[exp.Select]:
 
46
  for n in self._walk(tree):
47
  if isinstance(n, exp.Select):
48
  return n
@@ -50,27 +91,22 @@ class Verifier:
50
 
51
  def _has_group_by(self, tree: exp.Expression) -> bool:
52
  sel = self._first_select(tree)
53
- if not sel:
54
- return False
55
- # sqlglot stores GROUP BY on Select.group
56
- return bool(getattr(sel, "group", None))
57
 
58
  def _is_distinct_projection(self, tree: exp.Expression) -> bool:
59
  sel = self._first_select(tree)
60
  if not sel:
61
  return False
62
- # DISTINCT may appear as Select.distinct or a Distinct node
63
  if getattr(sel, "distinct", None):
64
  return True
65
  return any(isinstance(n, exp.Distinct) for n in self._walk(sel))
66
 
67
  def _has_windowed_aggregate(self, tree: exp.Expression) -> bool:
68
- # If there is any OVER(...) window, aggregates without GROUP BY can be legitimate
69
  return any(isinstance(n, exp.Window) for n in self._walk(tree))
70
 
71
  def _expr_contains_agg(self, node: exp.Expression) -> bool:
72
- """True if subtree contains an aggregate call (robust across sqlglot versions)."""
73
- # Build aggregate classes dynamically to avoid attr errors and fixed-length tuples
74
  agg_type_names = (
75
  "Count",
76
  "Sum",
@@ -81,26 +117,24 @@ class Verifier:
81
  "ArrayAgg",
82
  "StringAgg",
83
  )
84
- agg_types_list: list[type] = []
85
- for name in agg_type_names:
86
- t = getattr(exp, name, None)
87
- if isinstance(t, type):
88
- agg_types_list.append(t)
89
- AGG_TYPES: tuple[type, ...] = tuple(agg_types_list)
90
-
91
- # 1) Class-based check (if we found any known aggregate classes)
92
- if AGG_TYPES and any(isinstance(n, AGG_TYPES) for n in self._walk(node)):
93
  return True
94
 
95
- # 2) Fallback: generic function nodes with aggregate names
96
  Anonymous = getattr(exp, "Anonymous", None)
97
  func_like = (exp.Func,) + ((Anonymous,) if isinstance(Anonymous, type) else ())
98
- AGG_NAMES = {"count", "sum", "avg", "min", "max"}
99
 
100
- def _func_name(n: exp.Expression) -> str:
101
- name = getattr(n, "name", None)
102
- if isinstance(name, str) and name:
103
- return name.lower()
104
  this = getattr(n, "this", None)
105
  if isinstance(this, str):
106
  return this.lower()
@@ -110,82 +144,138 @@ class Verifier:
110
  return (str(this) or "").lower()
111
 
112
  for n in self._walk(node):
113
- if isinstance(n, func_like) and _func_name(n) in AGG_NAMES:
114
- return True
115
-
116
- return False
117
-
118
- def _has_nonagg_column(self, node: exp.Expression) -> bool:
119
- """Subtree contains a column reference that is NOT inside an aggregate."""
120
- # Check if there are any columns in this expression
121
- columns = [n for n in self._walk(node) if isinstance(n, exp.Column)]
122
- if not columns:
123
- return False
124
-
125
- # Check if all columns are inside aggregates
126
- for col in columns:
127
- # Walk up from column to see if it's inside an aggregate
128
- # is_in_agg = False
129
- # For simplicity, check if the entire expression contains both column and aggregate
130
- # A more precise check would require parent tracking
131
- if self._expr_contains_agg(node):
132
- # This is a simplified check - if the node has both columns and aggregates,
133
- # we need more complex logic to determine if columns are outside aggregates
134
- return True
135
- else:
136
- # No aggregates, so if there are columns, they're non-aggregate
137
  return True
138
  return False
139
 
140
- # ----------------------- Textual fallback helpers -------------------------
141
  def _clean_sql_for_fn_scan(self, sql: str) -> str:
142
- """Remove comments/strings so regex won't be fooled."""
143
  s = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) # block comments
144
  s = re.sub(r"--.*?$", " ", s, flags=re.MULTILINE) # line comments
145
  s = re.sub(
146
  r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`)", " ", s
147
- ) # quoted strings / idents
148
  s = re.sub(r"\s+", " ", s).strip()
149
  return s
150
 
151
- # ----------------------- Adapter result helpers ---------------------------
152
- def _extract_ok(self, exec_result: Any) -> Optional[bool]:
153
- if isinstance(exec_result, dict):
154
- v = exec_result.get("ok")
155
- if isinstance(v, bool):
156
- return v
 
 
 
 
 
 
 
157
  return None
158
 
159
- def _extract_error(self, exec_result: Any) -> Optional[str]:
160
- if isinstance(exec_result, dict):
161
- for k in ("error", "message", "detail"):
162
- if k in exec_result and exec_result[k]:
163
- return str(exec_result[k])
164
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- # ----------------------------- Main entry ---------------------------------
167
- def verify(self, sql: str, *, adapter: Any) -> StageResult:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  t0 = time.perf_counter()
169
  issues: List[str] = []
 
170
 
171
- # 1) Parse - Check for errors in the parsed result
172
- try:
173
- tree = sqlglot.parse_one(sql, read=None) # autodetect dialect
 
 
 
 
 
 
 
174
 
175
- # Check if the parse actually succeeded
 
 
176
  if tree is None:
177
  return StageResult(
178
  ok=False,
179
  error=["parse_error"],
180
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
181
  )
182
-
183
- # sqlglot may parse broken SQL as an "Unknown" or "Command" type
184
- # Check if we got a proper SQL statement type
185
  tree_type = type(tree).__name__
186
-
187
- # Check for common sqlglot error indicators
188
- # When sqlglot can't parse properly, it often creates Command or Unknown nodes
189
  if tree_type in ("Command", "Unknown"):
190
  verifier_checks_total.labels(ok="false").inc()
191
  verifier_failures_total.labels(reason="parse_error").inc()
@@ -194,36 +284,6 @@ class Verifier:
194
  error=["parse_error"],
195
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
196
  )
197
-
198
- # Also check if the tree has errors attribute (some versions of sqlglot)
199
- if hasattr(tree, "errors") and tree.errors:
200
- verifier_checks_total.labels(ok="false").inc()
201
- verifier_failures_total.labels(reason="parse_error").inc()
202
- return StageResult(
203
- ok=False,
204
- error=["parse_error"],
205
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
206
- )
207
-
208
- # Additional check: if it's not a recognized DML/DQL statement
209
- valid_types = ("Select", "With", "Union", "Intersect", "Except", "Values")
210
- if tree_type not in valid_types:
211
- # This might be a parse error disguised as a different statement type
212
- # Let's check if it looks like it should be a SELECT
213
- sql_lower = sql.lower().strip()
214
- if any(
215
- sql_lower.startswith(kw)
216
- for kw in ["selct", "slect", "selet", "seelct"]
217
- ):
218
- # Common misspellings of SELECT
219
- verifier_checks_total.labels(ok="false").inc()
220
- verifier_failures_total.labels(reason="parse_error").inc()
221
- return StageResult(
222
- ok=False,
223
- error=["parse_error"],
224
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
225
- )
226
-
227
  except Exception:
228
  verifier_checks_total.labels(ok="false").inc()
229
  verifier_failures_total.labels(reason="parse_error").inc()
@@ -233,29 +293,22 @@ class Verifier:
233
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
234
  )
235
 
236
- # 2) Semantic checks (AST-first)
237
  try:
238
  sel = self._first_select(tree)
239
  if sel:
240
  has_group = self._has_group_by(tree)
241
  has_window = self._has_windowed_aggregate(tree)
242
  is_distinct = self._is_distinct_projection(tree)
243
-
244
  select_items = list(getattr(sel, "expressions", []) or [])
245
  any_agg = any(self._expr_contains_agg(it) for it in select_items)
246
-
247
- # More precise check for non-aggregate columns
248
- any_nonagg_col = False
249
- for item in select_items:
250
- # Check if this select item has columns but no aggregates
251
- has_cols = any(isinstance(n, exp.Column) for n in self._walk(item))
252
- has_aggs = self._expr_contains_agg(item)
253
- if has_cols and not has_aggs:
254
- any_nonagg_col = True
255
- break
256
-
257
- # Core rule: aggregate + non-aggregate column without GROUP BY is an issue,
258
- # unless DISTINCT or windowed aggregate makes it legitimate.
259
  if (
260
  any_agg
261
  and any_nonagg_col
@@ -264,72 +317,111 @@ class Verifier:
264
  verifier_failures_total.labels(reason="semantic_error").inc()
265
  issues.append("aggregation_without_group_by")
266
  except Exception as e:
267
- # Don't crash the verifier; surface a soft issue and let fallback run
268
  verifier_failures_total.labels(reason="semantic_error").inc()
269
  issues.append(f"semantic_check_error:{e!s}")
270
-
271
- # 3) Fallback textual scan only if AST didn't already flag
272
- if not any("aggregation_without_group_by" in i for i in issues):
273
- try:
274
- cleaned = self._clean_sql_for_fn_scan(sql)
275
- has_agg_call = bool(self._AGG_CALL_RE.search(cleaned))
276
- has_group_kw = re.search(r"\bgroup\s+by\b", cleaned, re.IGNORECASE)
277
- has_over_kw = re.search(r"\bover\s*\(", cleaned, re.IGNORECASE)
278
- has_distinct_kw = re.search(
279
- r"\bselect\s+distinct\b", cleaned, re.IGNORECASE
280
  )
281
-
282
- if has_agg_call and not (
283
- has_group_kw or has_over_kw or has_distinct_kw
284
- ):
285
- m_sel = re.search(
286
- r"\bselect\s+(?P<sel>.+?)\s+\bfrom\b",
287
- cleaned,
288
- re.IGNORECASE | re.DOTALL,
289
- )
290
- if m_sel:
291
- select_list = m_sel.group("sel")
292
- # a comma strongly suggests mixing aggregate and non-aggregate in projection
293
- if "," in select_list:
 
 
 
 
 
 
 
 
294
  verifier_failures_total.labels(
295
- reason="agg_without_group_by"
296
  ).inc()
297
  issues.append("aggregation_without_group_by")
298
- except Exception:
299
- # ignore fallback errors
300
- pass
 
 
 
 
 
 
 
 
 
 
 
301
 
302
- # 4) Optional: cheap preview execution (adapter may be a stub in tests)
 
 
303
  try:
304
- exec_result = adapter.execute_preview(sql) if adapter else {"ok": True}
305
- ok_val = self._extract_ok(exec_result)
306
- if ok_val is False:
307
- err = self._extract_error(exec_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  verifier_failures_total.labels(reason="preview_exec_error").inc()
309
- issues.append(f"exec_error:{err}" if err else "exec_error")
310
  except Exception as e:
311
  verifier_failures_total.labels(reason="preview_exec_error").inc()
312
  issues.append(f"exec_exception:{e!s}")
313
 
314
- # 5) Final decision AFTER all checks (note: no early return before fallback)
315
- if issues:
316
- verifier_checks_total.labels(ok="false").inc()
317
- stage_duration_ms.labels("verifier").observe(_ms(t0) / 1.0)
 
 
 
 
 
 
 
 
 
 
 
318
  return StageResult(
319
  ok=False,
320
- error=issues,
321
  trace=StageTrace(
322
  stage=self.name, duration_ms=_ms(t0), notes={"issues": issues}
323
  ),
324
  )
325
 
326
- verifier_checks_total.labels(ok="true").inc()
327
- stage_duration_ms.labels("verifier").observe(_ms(t0) / 1.0)
328
- return StageResult(
329
- ok=True,
330
- data={"verified": True},
331
- trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
332
- )
333
-
334
- def run(self, *, sql: str, adapter: Any) -> StageResult:
335
- return self.verify(sql, adapter=adapter)
 
1
  from __future__ import annotations
 
2
  import re
3
  import time
4
+ from typing import Any, Iterable, List, Optional, Dict, Tuple
5
 
6
  import sqlglot
7
  from sqlglot import expressions as exp
 
9
  from nl2sql.types import StageResult, StageTrace
10
  from nl2sql.metrics import (
11
  verifier_checks_total,
 
12
  verifier_failures_total,
13
  )
14
 
15
 
16
  def _ms(t0: float) -> int:
17
+ """Return elapsed milliseconds since t0, as int."""
18
  return int((time.perf_counter() - t0) * 1000)
19
 
20
 
21
+ # ---------------- Small Levenshtein distance for schema matching ----------------
22
+ def _lev(a: str, b: str) -> int:
23
+ n = len(b)
24
+
25
+ dp = list(range(n + 1))
26
+ for i, ca in enumerate(a, 1):
27
+ prev, dp[0] = dp[0], i
28
+ for j, cb in enumerate(b, 1):
29
+ cur = min(
30
+ dp[j] + 1, # delete
31
+ dp[j - 1] + 1, # insert
32
+ prev + (0 if ca == cb else 1), # replace
33
+ )
34
+ prev, dp[j] = dp[j], cur
35
+ return dp[n]
36
+
37
+
38
def _closest(name: str, candidates: List[str]) -> Tuple[str, int]:
    """Return ``(candidate, distance)`` with the smallest edit distance to *name*.

    Matching is case-insensitive; ties keep the earliest candidate.  With an
    empty candidate list, returns (*name*, a huge sentinel distance).
    """
    target = name.lower()
    best_name, best_dist = name, 10**9
    for cand in candidates:
        d = _lev(target, cand.lower())
        if d < best_dist:
            best_name, best_dist = cand, d
    return best_name, best_dist
46
+
47
+
48
+ def _maybe_singular(plural: str, tables: List[str]) -> Optional[str]:
49
+ """Simple singularization heuristic: 'singers' -> 'singer'."""
50
+ if plural.endswith("s"):
51
+ cand = plural[:-1]
52
+ if cand in tables:
53
+ return cand
54
+ return None
55
+
56
+
57
+ # ---------------- Verifier with schema-aware repair ----------------
58
  class Verifier:
59
  name = "verifier"
60
 
61
+ # Aggregate call detector used by both AST and regex fallbacks
62
  _AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
63
 
64
+ # Fast token sanity: require SELECT and FROM to exist in the cleaned SQL
65
+ _REQ_SELECT = re.compile(r"\bselect\b", re.IGNORECASE)
66
+ _REQ_FROM = re.compile(r"\bfrom\b", re.IGNORECASE)
67
+
68
+ # ---------- AST helpers ----------
69
  def _walk(self, node: exp.Expression) -> Iterable[exp.Expression]:
70
+ """Depth-first traversal of a SQLGlot AST."""
71
  stack = [node]
72
  while stack:
73
  cur = stack.pop()
 
83
  stack.append(it)
84
 
85
  def _first_select(self, tree: exp.Expression) -> Optional[exp.Select]:
86
+ """Return the first SELECT node from the AST (if any)."""
87
  for n in self._walk(tree):
88
  if isinstance(n, exp.Select):
89
  return n
 
91
 
92
  def _has_group_by(self, tree: exp.Expression) -> bool:
93
  sel = self._first_select(tree)
94
+ return bool(getattr(sel, "group", None)) if sel else False
 
 
 
95
 
96
  def _is_distinct_projection(self, tree: exp.Expression) -> bool:
97
  sel = self._first_select(tree)
98
  if not sel:
99
  return False
 
100
  if getattr(sel, "distinct", None):
101
  return True
102
  return any(isinstance(n, exp.Distinct) for n in self._walk(sel))
103
 
104
  def _has_windowed_aggregate(self, tree: exp.Expression) -> bool:
 
105
  return any(isinstance(n, exp.Window) for n in self._walk(tree))
106
 
107
  def _expr_contains_agg(self, node: exp.Expression) -> bool:
108
+ """Return True if an expression contains an aggregate function."""
109
+ agg_names = {"count", "sum", "avg", "min", "max"}
110
  agg_type_names = (
111
  "Count",
112
  "Sum",
 
117
  "ArrayAgg",
118
  "StringAgg",
119
  )
120
+ agg_types = tuple(
121
+ t
122
+ for t in (getattr(exp, n, None) for n in agg_type_names)
123
+ if isinstance(t, type)
124
+ )
125
+
126
+ # AST type-based check (preferred)
127
+ if agg_types and any(isinstance(n, agg_types) for n in self._walk(node)):
 
128
  return True
129
 
130
+ # Fallback: function-like name check
131
  Anonymous = getattr(exp, "Anonymous", None)
132
  func_like = (exp.Func,) + ((Anonymous,) if isinstance(Anonymous, type) else ())
 
133
 
134
+ def _fname(n: exp.Expression) -> str:
135
+ nm = getattr(n, "name", None)
136
+ if isinstance(nm, str) and nm:
137
+ return nm.lower()
138
  this = getattr(n, "this", None)
139
  if isinstance(this, str):
140
  return this.lower()
 
144
  return (str(this) or "").lower()
145
 
146
  for n in self._walk(node):
147
+ if isinstance(n, func_like) and _fname(n) in agg_names:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  return True
149
  return False
150
 
 
151
  def _clean_sql_for_fn_scan(self, sql: str) -> str:
152
+ """Normalize SQL before scanning for function names or keywords."""
153
  s = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) # block comments
154
  s = re.sub(r"--.*?$", " ", s, flags=re.MULTILINE) # line comments
155
  s = re.sub(
156
  r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`)", " ", s
157
+ ) # quoted strings
158
  s = re.sub(r"\s+", " ", s).strip()
159
  return s
160
 
161
+ # ---------------- Schema-Guard Repair ----------------
162
+ def _schema_dict(self, adapter: Any) -> Optional[Dict[str, List[str]]]:
163
+ """Fetch schema dict {table: [columns]} from adapter if available."""
164
+ if not adapter:
165
+ return None
166
+ get = getattr(adapter, "schema_dict", None)
167
+ if callable(get):
168
+ try:
169
+ d = get()
170
+ if isinstance(d, dict):
171
+ return {str(k): list(v) for k, v in d.items()}
172
+ except Exception:
173
+ return None
174
  return None
175
 
176
+ def _repair_with_schema(
177
+ self, sql: str, schema: Dict[str, List[str]]
178
+ ) -> Tuple[str, bool, List[str]]:
179
+ """Try to fix table/column names using schema similarity (singularize + closest edit-distance <= 2)."""
180
+ notes: List[str] = []
181
+ try:
182
+ ast = sqlglot.parse_one(sql)
183
+ except Exception as e:
184
+ return sql, False, [f"parse_error:{e!s}"]
185
+
186
+ tables = list(schema.keys())
187
+ changed = False
188
+
189
+ # Fix table names
190
+ def _fix_table(node: exp.Expression) -> exp.Expression:
191
+ nonlocal changed
192
+ if isinstance(node, exp.Table):
193
+ orig = node.name
194
+ if orig in schema:
195
+ return node
196
+ s1 = _maybe_singular(orig, tables)
197
+ if s1:
198
+ changed = True
199
+ return exp.Table(this=sqlglot.to_identifier(s1))
200
+ best, dist = _closest(orig, tables)
201
+ if dist <= 2:
202
+ changed = True
203
+ return exp.Table(this=sqlglot.to_identifier(best))
204
+ return node
205
+
206
+ ast = ast.transform(_fix_table)
207
+
208
+ # Fix column names
209
+ def _fix_col(node: exp.Expression) -> exp.Expression:
210
+ nonlocal changed
211
+ if isinstance(node, exp.Column):
212
+ name = node.name
213
+ if not name:
214
+ return node
215
+ tbl = node.table
216
+ if tbl and tbl in schema:
217
+ candidates = schema[tbl]
218
+ else:
219
+ candidates = [c for cols in schema.values() for c in cols]
220
+ if name in candidates:
221
+ return node
222
+ best, dist = _closest(name, candidates) if candidates else (name, 99)
223
+ if dist <= 2:
224
+ changed = True
225
+ node.set("this", sqlglot.to_identifier(best))
226
+ return node
227
+
228
+ ast = ast.transform(_fix_col)
229
+
230
+ if not changed:
231
+ return sql, True, notes
232
 
233
+ try:
234
+ repaired = ast.sql(dialect="sqlite")
235
+ except Exception as e:
236
+ return sql, False, notes + [f"rebuild_error:{e!s}"]
237
+
238
+ notes.append("schema_guard_repair")
239
+ return repaired, True, notes
240
+
241
+ # ---------------- Main verifier logic ----------------
242
+ def verify(
243
+ self, sql: str, *, exec_result: Any = None, adapter: Any = None
244
+ ) -> StageResult:
245
+ """
246
+ Verify syntax, basic semantics, and optionally schema correctness and preview-execution.
247
+
248
+ Returns:
249
+ StageResult with:
250
+ - ok: boolean
251
+ - data: may include {"verified": True, "sql": <repaired_sql>}
252
+ - trace: StageTrace(stage="verifier", duration_ms=...)
253
+ """
254
  t0 = time.perf_counter()
255
  issues: List[str] = []
256
+ repaired_sql = None
257
 
258
+ # 0) Fast token sanity: must contain SELECT and FROM (handles typos like SELCT/FRM).
259
+ sql_scan = self._clean_sql_for_fn_scan(sql)
260
+ if not self._REQ_SELECT.search(sql_scan) or not self._REQ_FROM.search(sql_scan):
261
+ verifier_checks_total.labels(ok="false").inc()
262
+ verifier_failures_total.labels(reason="parse_error").inc()
263
+ return StageResult(
264
+ ok=False,
265
+ error=["parse_error"],
266
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
267
+ )
268
 
269
+ # 1) Syntax validation via sqlglot
270
+ try:
271
+ tree = sqlglot.parse_one(sql, read=None)
272
  if tree is None:
273
  return StageResult(
274
  ok=False,
275
  error=["parse_error"],
276
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
277
  )
 
 
 
278
  tree_type = type(tree).__name__
 
 
 
279
  if tree_type in ("Command", "Unknown"):
280
  verifier_checks_total.labels(ok="false").inc()
281
  verifier_failures_total.labels(reason="parse_error").inc()
 
284
  error=["parse_error"],
285
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
286
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  except Exception:
288
  verifier_checks_total.labels(ok="false").inc()
289
  verifier_failures_total.labels(reason="parse_error").inc()
 
293
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
294
  )
295
 
296
+ # 2) Semantic rule: avoid aggregate + non-aggregate mix without GROUP BY (unless DISTINCT/window)
297
  try:
298
  sel = self._first_select(tree)
299
  if sel:
300
  has_group = self._has_group_by(tree)
301
  has_window = self._has_windowed_aggregate(tree)
302
  is_distinct = self._is_distinct_projection(tree)
 
303
  select_items = list(getattr(sel, "expressions", []) or [])
304
  any_agg = any(self._expr_contains_agg(it) for it in select_items)
305
+ any_nonagg_col = any(
306
+ (
307
+ any(isinstance(n, exp.Column) for n in self._walk(it))
308
+ and not self._expr_contains_agg(it)
309
+ )
310
+ for it in select_items
311
+ )
 
 
 
 
 
 
312
  if (
313
  any_agg
314
  and any_nonagg_col
 
317
  verifier_failures_total.labels(reason="semantic_error").inc()
318
  issues.append("aggregation_without_group_by")
319
  except Exception as e:
 
320
  verifier_failures_total.labels(reason="semantic_error").inc()
321
  issues.append(f"semantic_check_error:{e!s}")
322
+ # 2b) Regex fallback for aggregate + non-aggregate without GROUP BY.
323
+ # Skip if DISTINCT or any WINDOW (OVER ...) is present in the SELECT list.
324
+ try:
325
+ low = sql_scan.lower()
326
+ if "group by" not in low and "distinct" not in low:
327
+ m = re.search(
328
+ r"select\s+(?P<sel>.+?)\s+from\b",
329
+ sql_scan,
330
+ flags=re.IGNORECASE | re.DOTALL,
 
331
  )
332
+ if m:
333
+ sel_clause = m.group("sel")
334
+ # If window functions are present, allow (COUNT(*) OVER (...), etc.)
335
+ if re.search(r"\bover\b", sel_clause, flags=re.IGNORECASE):
336
+ pass # windowed aggregates are acceptable without GROUP BY
337
+ else:
338
+ has_agg = bool(self._AGG_CALL_RE.search(sel_clause))
339
+ # Heuristic: presence of a comma OR a bare identifier besides pure aggregate-only select
340
+ has_bare_col = "," in sel_clause or (
341
+ bool(re.search(r"\b[a-zA-Z_][\w.]*\b", sel_clause))
342
+ and not re.fullmatch(
343
+ r"\s*(count|sum|avg|min|max)\s*\([^)]*\)\s*",
344
+ sel_clause,
345
+ flags=re.IGNORECASE,
346
+ )
347
+ )
348
+ if (
349
+ has_agg
350
+ and has_bare_col
351
+ and "aggregation_without_group_by" not in issues
352
+ ):
353
  verifier_failures_total.labels(
354
+ reason="semantic_error"
355
  ).inc()
356
  issues.append("aggregation_without_group_by")
357
+ except Exception:
358
+ # Non-fatal; AST path already attempted.
359
+ pass
360
+
361
+ # 3) Schema-based auto-repair (optional)
362
+ schema = self._schema_dict(adapter)
363
+ if schema:
364
+ fixed, ok_fix, notes = self._repair_with_schema(sql, schema)
365
+ if ok_fix is True and fixed != sql:
366
+ repaired_sql = fixed
367
+ if notes:
368
+ issues.extend(
369
+ [f"note:{n}" for n in notes if not n.startswith("parse_error")]
370
+ )
371
 
372
+ # 4) Preview execution check:
373
+ # - If exec_result is provided, use it directly
374
+ # - Otherwise, if adapter has execute_preview, run it
375
  try:
376
+ if exec_result is not None:
377
+ er = exec_result
378
+ elif adapter is not None and hasattr(adapter, "execute_preview"):
379
+ er = adapter.execute_preview(repaired_sql or sql)
380
+ else:
381
+ er = {"ok": True}
382
+
383
+ ok_val = (
384
+ isinstance(er, dict) and isinstance(er.get("ok"), bool) and er["ok"]
385
+ )
386
+ if not ok_val:
387
+ msg = None
388
+ if isinstance(er, dict):
389
+ for k in ("error", "message", "detail"):
390
+ if k in er and er[k]:
391
+ msg = str(er[k])
392
+ break
393
  verifier_failures_total.labels(reason="preview_exec_error").inc()
394
+ issues.append(f"exec_error:{msg or 'preview_failed'}")
395
  except Exception as e:
396
  verifier_failures_total.labels(reason="preview_exec_error").inc()
397
  issues.append(f"exec_exception:{e!s}")
398
 
399
+ # 5) Final result and trace
400
+ is_ok: bool = (not issues) or all(i.startswith("note:") for i in issues)
401
+ ok_label: str = "true" if is_ok else "false"
402
+ verifier_checks_total.labels(ok=ok_label).inc()
403
+
404
+ if is_ok:
405
+ data: Dict[str, Any] = {"verified": True}
406
+ if repaired_sql:
407
+ data["sql"] = repaired_sql
408
+ return StageResult(
409
+ ok=True,
410
+ data=data,
411
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
412
+ )
413
+ else:
414
  return StageResult(
415
  ok=False,
416
+ error=[i for i in issues if not i.startswith("note:")],
417
  trace=StageTrace(
418
  stage=self.name, duration_ms=_ms(t0), notes={"issues": issues}
419
  ),
420
  )
421
 
422
+ # Public alias for backward compatibility
423
+ def run(
424
+ self, *, sql: str, exec_result: Any = None, adapter: Any = None
425
+ ) -> StageResult:
426
+ """Back-compat wrapper around verify()."""
427
+ return self.verify(sql, exec_result=exec_result, adapter=adapter)