Melika Kheirieh commited on
Commit
8b2d603
·
1 Parent(s): e3e0ac5

perf(planner): trim relevant tables (+cache) to cut latency; keep repair loop & rich traces

Browse files
benchmarks/results_pro/20251109-123149/eval.jsonl ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {"source": "spider", "db_id": "concert_singer", "query": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer", "ok": true, "latency_ms": 11573, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7978, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 3588, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 796, "token_out": 19, "cost_usd": 0.0001308, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 3, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"sql_length": 27, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": false, "verified": true}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
2
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "select count(*) from singer", "ok": true, "latency_ms": 9087, "em": 1.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 7653, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 1432, "summary": "ok", "notes": {"rationale_len": 30}, "token_in": 789, "token_out": 19, "cost_usd": 0.00012974999999999998, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 1, "sql_length": 27}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"sql_length": 27, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": false, "verified": true}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
3
+ {"source": "spider", "db_id": "concert_singer", "query": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "", "ok": true, "latency_ms": 0, "em": 0.0, "sm": 0.0, "exec_acc": 0.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "ambiguous", "notes": {"ambiguous": true, "questions_len": 1}, "skipped": false}]}
4
+ {"source": "spider", "db_id": "concert_singer", "query": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "select Name, Country, Age from singer order by Age desc", "ok": true, "latency_ms": 10200, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 8373, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 1824, "summary": "ok", "notes": {"rationale_len": 85}, "token_in": 801, "token_out": 37, "cost_usd": 0.00014235, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 2, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 6, "col_count": 3, "sql_length": 55}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "ok", "notes": {"sql_length": 55, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": false, "mixes_cols": false, "verified": true}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
5
+ {"source": "spider", "db_id": "concert_singer", "query": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "select avg(Age), min(Age), max(Age) from singer where Country = 'France'", "ok": false, "latency_ms": 20765, "em": 0.0, "sm": 1.0, "exec_acc": 1.0, "error": null, "trace": [{"stage": "detector", "duration_ms": 0, "summary": "clear", "notes": {"ambiguous": false, "questions_len": 0}, "skipped": false}, {"stage": "planner", "duration_ms": 9562, "summary": "ok", "notes": {}, "skipped": false}, {"stage": "generator", "duration_ms": 4303, "summary": "ok", "notes": {"rationale_len": 67}, "token_in": 827, "token_out": 42, "cost_usd": 0.00014924999999999997, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "safety", "duration_ms": 1, "summary": "ok", "notes": {}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "failed", "notes": {"sql_length": 72, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": true, "verified": false}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 5379, "summary": "ok", "notes": {"old_sql_len": 72, "new_sql_len": 80}, "token_in": 328, "token_out": 24, "cost_usd": 6.36e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 80}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "failed", "notes": {"sql_length": 80, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": true, "verified": false}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "repair", "duration_ms": 1516, "summary": "ok", "notes": {"old_sql_len": 80, "new_sql_len": 72}, "token_in": 332, "token_out": 25, "cost_usd": 6.48e-05, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "executor", "duration_ms": 1, "summary": "ok", "notes": {"row_count": 1, "col_count": 3, "sql_length": 72}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}, {"stage": "verifier", "duration_ms": 0, "summary": "failed", "notes": {"sql_length": 72, "has_select": true, "has_from": true, "has_over": false, "has_group_by": false, "has_distinct": false, "has_aggregate": true, "mixes_cols": true, "verified": false}, "token_in": null, "token_out": null, "cost_usd": null, "sql_length": null, "row_count": null, "verified": null, "error_type": null, "repair_attempts": null, "skipped": false}]}
benchmarks/results_pro/20251109-123149/latency_histogram.png ADDED
benchmarks/results_pro/20251109-123149/latency_per_stage.png ADDED
benchmarks/results_pro/20251109-123149/metrics_overview.png ADDED
benchmarks/results_pro/20251109-123149/results.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ db_id,query,ok,em,sm,exec_acc,latency_ms
2
+ concert_singer,"How many singers do we have?",✅,1.0,1.0,1.0,11573
3
+ concert_singer,"What is the total number of singers?",✅,1.0,1.0,1.0,9087
4
+ concert_singer,"Show name, country, age for all singers ordered by age from the oldest to the youngest.",✅,0.0,0.0,0.0,0
5
+ concert_singer,"What are the names, countries, and ages for every singer in descending order of age?",✅,0.0,1.0,1.0,10200
6
+ concert_singer,"What is the average, minimum, and maximum age of all singers from France?",❌,0.0,1.0,1.0,20765
benchmarks/results_pro/20251109-123149/summary.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "timestamp": "2025-11-09T12:32:41",
3
+ "split": "dev",
4
+ "config": "configs/sqlite_pipeline.yaml",
5
+ "total": 5,
6
+ "success": 4,
7
+ "success_rate": 0.8,
8
+ "avg_latency_ms": 10325.0,
9
+ "p50_latency_ms": 10200.0,
10
+ "p95_latency_ms": 18926.6,
11
+ "EM": 0.4,
12
+ "SM": 0.8,
13
+ "ExecAcc": 0.8,
14
+ "detector_avg_ms": 0.0,
15
+ "planner_avg_ms": 8391.5,
16
+ "generator_avg_ms": 2786.75,
17
+ "safety_avg_ms": 1.5,
18
+ "executor_avg_ms": 1.33,
19
+ "verifier_avg_ms": 0.0,
20
+ "repair_avg_ms": 3447.5
21
+ }
nl2sql/planner.py CHANGED
@@ -1,22 +1,118 @@
1
  from __future__ import annotations
2
 
3
- from typing import Dict, Any
 
4
 
 
 
 
 
 
 
5
 
6
- class Planner:
7
- """Planner wrapper around the LLM provider.
8
 
9
- The factory constructs it with `Planner(llm=llm)`, so we accept `llm` here.
10
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def __init__(self, *, llm, model_id: str | None = None) -> None:
13
  self.llm = llm
14
- self.model_id = model_id or getattr(llm, "model", "unknown")
 
 
 
15
 
16
  def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
17
- plan_text, pin, pout, cost = self.llm.plan(
18
- user_query=user_query, schema_preview=schema_preview
 
 
 
 
19
  )
 
 
 
 
 
 
 
 
20
  return {
21
  "plan": plan_text,
22
  "usage": {
 
1
  from __future__ import annotations
2
 
3
+ import re
4
+ from typing import Any, Dict, List, Tuple, Optional
5
 
6
+ __all__ = ["Planner"]
7
+
8
+
9
+ # --------- Heuristic schema trimming (safe, mypy-clean) ---------
10
+ def _tokenize_lower(s: str) -> List[str]:
11
+ return re.findall(r"[a-z_]+", (s or "").lower())
12
 
 
 
13
 
14
+ def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
15
  """
16
+ Parse plain-text schema into [(table_name, lines)] blocks,
17
+ supporting both 'Table: name' and 'CREATE TABLE name (' styles.
18
+ """
19
+ blocks: List[Tuple[str, List[str]]] = []
20
+ cur_name: Optional[str] = None
21
+ cur_lines: List[str] = []
22
+
23
+ def _flush() -> None:
24
+ nonlocal cur_name, cur_lines
25
+ if cur_name is not None and cur_lines:
26
+ blocks.append((cur_name, cur_lines[:]))
27
+ cur_name, cur_lines = None, []
28
+
29
+ for line in (schema_text or "").splitlines():
30
+ m = re.search(r"Table:\s*(\w+)", line, flags=re.IGNORECASE)
31
+ m2 = re.search(r"CREATE\s+TABLE\s+(\w+)\s*\(", line, flags=re.IGNORECASE)
32
+
33
+ started = False
34
+ name: Optional[str] = None
35
+ if m is not None:
36
+ name = m.group(1)
37
+ started = True
38
+ elif m2 is not None:
39
+ name = m2.group(1)
40
+ started = True
41
+
42
+ if started and name:
43
+ _flush()
44
+ cur_name = name
45
+ cur_lines.append(line)
46
+ else:
47
+ if cur_name is not None:
48
+ cur_lines.append(line)
49
+
50
+ if cur_name is not None and line.strip().endswith(");"):
51
+ _flush()
52
+
53
+ _flush()
54
+ return blocks
55
+
56
+
57
+ def _pick_relevant_tables(schema_text: str, question: str, k: int = 3) -> str:
58
+ """Keep up to k tables with highest lexical overlap with the question."""
59
+ try:
60
+ blocks = _table_blocks(schema_text)
61
+ if not blocks:
62
+ return schema_text
63
+
64
+ q_toks = set(_tokenize_lower(question))
65
+ scored: List[Tuple[int, str, List[str]]] = []
66
+ for name, lines in blocks:
67
+ score = sum(1 for w in _tokenize_lower(name) if w in q_toks)
68
+ cols_line = " ".join(lines)
69
+ cols = re.findall(r"\b([A-Za-z_]\w*)\b", cols_line)
70
+ score += min(2, sum(1 for c in cols if c.lower() in q_toks))
71
+ scored.append((score, name, lines))
72
+
73
+ scored.sort(key=lambda t: t[0], reverse=True)
74
+ keep = [b for b in scored[: max(1, k)] if b[0] > 0]
75
+ if not keep:
76
+ keep = scored[: max(1, k)]
77
+
78
+ out_lines: List[str] = []
79
+ for _, _, lines in keep:
80
+ out_lines.extend(lines)
81
+ if lines and lines[-1].strip() != "":
82
+ out_lines.append("")
83
+ trimmed = "\n".join(out_lines).strip()
84
+ return trimmed if trimmed else schema_text
85
+ except Exception:
86
+ return schema_text
87
+
88
+
89
+ # ------------------------------ Planner ------------------------------
90
+ class Planner:
91
+ """Planner wrapper around the LLM provider."""
92
 
93
  def __init__(self, *, llm, model_id: str | None = None) -> None:
94
  self.llm = llm
95
+ # ensure model_id is always a str (for mypy)
96
+ self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
97
+ # in-memory cache: (model, hash(q), hash(trimmed)) → (plan, pin, pout, cost)
98
+ self._plan_cache: dict[tuple[str, int, int], tuple[str, int, int, float]] = {}
99
 
100
  def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
101
+ trimmed = _pick_relevant_tables(schema_preview or "", user_query or "", k=3)
102
+
103
+ key: tuple[str, int, int] = (
104
+ self.model_id,
105
+ hash(user_query or ""),
106
+ hash(trimmed),
107
  )
108
+ if key in self._plan_cache:
109
+ plan_text, pin, pout, cost = self._plan_cache[key]
110
+ else:
111
+ plan_text, pin, pout, cost = self.llm.plan(
112
+ user_query=user_query, schema_preview=trimmed
113
+ )
114
+ self._plan_cache[key] = (plan_text, pin, pout, cost)
115
+
116
  return {
117
  "plan": plan_text,
118
  "usage": {