Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
8b2d603
1
Parent(s):
e3e0ac5
perf(planner): trim relevant tables (+cache) to cut latency; keep repair loop & rich traces
Browse files- benchmarks/results_pro/20251109-123149/eval.jsonl +5 -0
- benchmarks/results_pro/20251109-123149/latency_histogram.png +0 -0
- benchmarks/results_pro/20251109-123149/latency_per_stage.png +0 -0
- benchmarks/results_pro/20251109-123149/metrics_overview.png +0 -0
- benchmarks/results_pro/20251109-123149/results.csv +6 -0
- benchmarks/results_pro/20251109-123149/summary.json +21 -0
- nl2sql/planner.py +103 -7
benchmarks/results_pro/20251109-123149/eval.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer", "ok": true, "latency_ms": 11573, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7978, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 3588, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 796, "token_out": 19, "cost_usd": 0.0001308, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 3, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"sql_length": 27, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": false, "verified": true}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
|
| 2 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer", "ok": true, "latency_ms": 9087, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7653, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 1432, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 789, "token_out": 19, "cost_usd": 0.00012974999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"sql_length": 27, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": false, "verified": true}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
|
| 3 |
+
{"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}, "skipped": false}]}
|
| 4 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc", "ok": true, "latency_ms": 10200, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 8373, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 1824, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 801, "token_out": 37, "cost_usd": 0.00014235, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 55}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"sql_length": 55, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": false, "mixes_cols": false, "verified": true}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
|
| 5 |
+
{"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": false, "latency_ms": 20765, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 9562, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 4303, "summary": "ok", "notes": {"rationale_len": 67}, "token_in": 827, "token_out": 42, "cost_usd": 0.00014924999999999997, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "failed", "notes": {"sql_length": 72, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": true, "verified": false}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 5379, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 328, "token_out": 24, "cost_usd": 6.36e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 80}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "failed", "notes": {"sql_length": 80, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": true, "verified": false}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1516, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 332, "token_out": 25, "cost_usd": 6.48e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "failed", "notes": {"sql_length": 72, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": true, "verified": false}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
|
benchmarks/results_pro/20251109-123149/latency_histogram.png
ADDED
|
benchmarks/results_pro/20251109-123149/latency_per_stage.png
ADDED
|
benchmarks/results_pro/20251109-123149/metrics_overview.png
ADDED
|
benchmarks/results_pro/20251109-123149/results.csv
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
db_id,query,ok,em,sm,exec_acc,latency_ms
|
| 2 |
+
concert_singer,"How many singers do we have?",✅,1.0,1.0,1.0,11573
|
| 3 |
+
concert_singer,"What is the total number of singers?",✅,1.0,1.0,1.0,9087
|
| 4 |
+
concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",✅,0.0,0.0,0.0,0
|
| 5 |
+
concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",✅,0.0,1.0,1.0,10200
|
| 6 |
+
concert_singer,"What is the average, minimum, and maximum age of all singers from France?",❌,0.0,1.0,1.0,20765
|
benchmarks/results_pro/20251109-123149/summary.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"timestamp": "2025-11-09T12:32:41",
|
| 3 |
+
"split": "dev",
|
| 4 |
+
"config": "configs/sqlite_pipeline.yaml",
|
| 5 |
+
"total": 5,
|
| 6 |
+
"success": 4,
|
| 7 |
+
"success_rate": 0.8,
|
| 8 |
+
"avg_latency_ms": 10325.0,
|
| 9 |
+
"p50_latency_ms": 10200.0,
|
| 10 |
+
"p95_latency_ms": 18926.6,
|
| 11 |
+
"EM": 0.4,
|
| 12 |
+
"SM": 0.8,
|
| 13 |
+
"ExecAcc": 0.8,
|
| 14 |
+
"detector_avg_ms": 0.0,
|
| 15 |
+
"planner_avg_ms": 8391.5,
|
| 16 |
+
"generator_avg_ms": 2786.75,
|
| 17 |
+
"safety_avg_ms": 1.5,
|
| 18 |
+
"executor_avg_ms": 1.33,
|
| 19 |
+
"verifier_avg_ms": 0.0,
|
| 20 |
+
"repair_avg_ms": 3447.5
|
| 21 |
+
}
|
nl2sql/planner.py
CHANGED
|
@@ -1,22 +1,118 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
|
|
|
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
class Planner:
|
| 7 |
-
"""Planner wrapper around the LLM provider.
|
| 8 |
|
| 9 |
-
|
| 10 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def __init__(self, *, llm, model_id: str | None = None) -> None:
|
| 13 |
self.llm = llm
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
return {
|
| 21 |
"plan": plan_text,
|
| 22 |
"usage": {
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any, Dict, List, Tuple, Optional
|
| 5 |
|
| 6 |
+
__all__ = ["Planner"]
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# --------- Heuristic schema trimming (safe, mypy-clean) ---------
|
| 10 |
+
def _tokenize_lower(s: str) -> List[str]:
|
| 11 |
+
return re.findall(r"[a-z_]+", (s or "").lower())
|
| 12 |
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
|
| 15 |
"""
|
| 16 |
+
Parse plain-text schema into [(table_name, lines)] blocks,
|
| 17 |
+
supporting both 'Table: name' and 'CREATE TABLE name (' styles.
|
| 18 |
+
"""
|
| 19 |
+
blocks: List[Tuple[str, List[str]]] = []
|
| 20 |
+
cur_name: Optional[str] = None
|
| 21 |
+
cur_lines: List[str] = []
|
| 22 |
+
|
| 23 |
+
def _flush() -> None:
|
| 24 |
+
nonlocal cur_name, cur_lines
|
| 25 |
+
if cur_name is not None and cur_lines:
|
| 26 |
+
blocks.append((cur_name, cur_lines[:]))
|
| 27 |
+
cur_name, cur_lines = None, []
|
| 28 |
+
|
| 29 |
+
for line in (schema_text or "").splitlines():
|
| 30 |
+
m = re.search(r"Table:\s*(\w+)", line, flags=re.IGNORECASE)
|
| 31 |
+
m2 = re.search(r"CREATE\s+TABLE\s+(\w+)\s*\(", line, flags=re.IGNORECASE)
|
| 32 |
+
|
| 33 |
+
started = False
|
| 34 |
+
name: Optional[str] = None
|
| 35 |
+
if m is not None:
|
| 36 |
+
name = m.group(1)
|
| 37 |
+
started = True
|
| 38 |
+
elif m2 is not None:
|
| 39 |
+
name = m2.group(1)
|
| 40 |
+
started = True
|
| 41 |
+
|
| 42 |
+
if started and name:
|
| 43 |
+
_flush()
|
| 44 |
+
cur_name = name
|
| 45 |
+
cur_lines.append(line)
|
| 46 |
+
else:
|
| 47 |
+
if cur_name is not None:
|
| 48 |
+
cur_lines.append(line)
|
| 49 |
+
|
| 50 |
+
if cur_name is not None and line.strip().endswith(");"):
|
| 51 |
+
_flush()
|
| 52 |
+
|
| 53 |
+
_flush()
|
| 54 |
+
return blocks
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _pick_relevant_tables(schema_text: str, question: str, k: int = 3) -> str:
|
| 58 |
+
"""Keep up to k tables with highest lexical overlap with the question."""
|
| 59 |
+
try:
|
| 60 |
+
blocks = _table_blocks(schema_text)
|
| 61 |
+
if not blocks:
|
| 62 |
+
return schema_text
|
| 63 |
+
|
| 64 |
+
q_toks = set(_tokenize_lower(question))
|
| 65 |
+
scored: List[Tuple[int, str, List[str]]] = []
|
| 66 |
+
for name, lines in blocks:
|
| 67 |
+
score = sum(1 for w in _tokenize_lower(name) if w in q_toks)
|
| 68 |
+
cols_line = " ".join(lines)
|
| 69 |
+
cols = re.findall(r"\b([A-Za-z_]\w*)\b", cols_line)
|
| 70 |
+
score += min(2, sum(1 for c in cols if c.lower() in q_toks))
|
| 71 |
+
scored.append((score, name, lines))
|
| 72 |
+
|
| 73 |
+
scored.sort(key=lambda t: t[0], reverse=True)
|
| 74 |
+
keep = [b for b in scored[: max(1, k)] if b[0] > 0]
|
| 75 |
+
if not keep:
|
| 76 |
+
keep = scored[: max(1, k)]
|
| 77 |
+
|
| 78 |
+
out_lines: List[str] = []
|
| 79 |
+
for _, _, lines in keep:
|
| 80 |
+
out_lines.extend(lines)
|
| 81 |
+
if lines and lines[-1].strip() != "":
|
| 82 |
+
out_lines.append("")
|
| 83 |
+
trimmed = "\n".join(out_lines).strip()
|
| 84 |
+
return trimmed if trimmed else schema_text
|
| 85 |
+
except Exception:
|
| 86 |
+
return schema_text
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ------------------------------ Planner ------------------------------
|
| 90 |
+
class Planner:
|
| 91 |
+
"""Planner wrapper around the LLM provider."""
|
| 92 |
|
| 93 |
def __init__(self, *, llm, model_id: str | None = None) -> None:
|
| 94 |
self.llm = llm
|
| 95 |
+
# ensure model_id is always a str (for mypy)
|
| 96 |
+
self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
|
| 97 |
+
# in-memory cache: (model, hash(q), hash(trimmed)) → (plan, pin, pout, cost)
|
| 98 |
+
self._plan_cache: dict[tuple[str, int, int], tuple[str, int, int, float]] = {}
|
| 99 |
|
| 100 |
def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
|
| 101 |
+
trimmed = _pick_relevant_tables(schema_preview or "", user_query or "", k=3)
|
| 102 |
+
|
| 103 |
+
key: tuple[str, int, int] = (
|
| 104 |
+
self.model_id,
|
| 105 |
+
hash(user_query or ""),
|
| 106 |
+
hash(trimmed),
|
| 107 |
)
|
| 108 |
+
if key in self._plan_cache:
|
| 109 |
+
plan_text, pin, pout, cost = self._plan_cache[key]
|
| 110 |
+
else:
|
| 111 |
+
plan_text, pin, pout, cost = self.llm.plan(
|
| 112 |
+
user_query=user_query, schema_preview=trimmed
|
| 113 |
+
)
|
| 114 |
+
self._plan_cache[key] = (plan_text, pin, pout, cost)
|
| 115 |
+
|
| 116 |
return {
|
| 117 |
"plan": plan_text,
|
| 118 |
"usage": {
|