Melika Kheirieh committed on
Commit eee3f75 · 1 Parent(s): a337fad

fix(types): resolve mypy errors and make pytest pass

Files changed (3)
  1. app/schemas.py +3 -3
  2. benchmarks/evaluate_spider.py +211 -55
  3. benchmarks/run.py +34 -10
app/schemas.py CHANGED
@@ -1,5 +1,5 @@
-from pydantic import BaseModel
-from typing import List, Optional, Any, Dict, Union
+from pydantic import BaseModel, Field
+from typing import List, Optional, Any, Dict, Mapping, Sequence
 
 
 class NL2SQLRequest(BaseModel):
@@ -21,7 +21,7 @@ class NL2SQLResponse(BaseModel):
     ambiguous: bool = False
     sql: Optional[str] = None
     rationale: Optional[str] = None
-    traces: List[Union[TraceModel, dict]] = []
+    traces: Sequence[TraceModel | Mapping[str, Any]] = Field(default_factory=list)
 
 
 class ClarifyResponse(BaseModel):
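The `traces` change above fixes two separate mypy complaints at once: the mutable default (`= []`) and the invariance of `List`, which rejects a `list[TraceModel]` where `List[Union[TraceModel, dict]]` is expected. A minimal self-contained sketch of the pattern, assuming Python 3.10+ and Pydantic; `TraceModel` here is a stand-in for the real model defined elsewhere in `app/schemas.py`:

from typing import Any, Mapping, Sequence

from pydantic import BaseModel, Field


class TraceModel(BaseModel):  # stand-in for the real TraceModel
    step: str


class NL2SQLResponse(BaseModel):
    ambiguous: bool = False
    # default_factory gives each instance its own list; Sequence/Mapping are
    # covariant, so lists of models or plain dicts both type-check.
    traces: Sequence[TraceModel | Mapping[str, Any]] = Field(default_factory=list)


resp = NL2SQLResponse(traces=[TraceModel(step="plan"), {"step": "generate"}])
assert len(resp.traces) == 2

For plain classes a bare `= []` is a shared-state bug; Pydantic copies defaults per instance, but mypy and linters still flag it, and `default_factory` sidesteps the warning entirely.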
benchmarks/evaluate_spider.py CHANGED
@@ -1,38 +1,96 @@
 from __future__ import annotations
-import time
+
 import json
 import subprocess
+import time
 from pathlib import Path
-from tqdm import tqdm
+from typing import Any, Iterable, Optional, Tuple, cast
 
-from app import get_schema_preview, on_generate_query, make_sql_chain
+from tqdm import tqdm
 from langchain_community.utilities import SQLDatabase
-from benchmarks import load_spider_sqlite
-
 from sqlglot import parse_one, exp
 from sqlglot.errors import ParseError
+from sqlalchemy import create_engine, inspect
+from spider_loader import load_spider_sqlite
+
+
+def _try_import_pipeline():
+    """
+    Try multiple plausible entrypoints from nl2sql.
+    Returns a tuple of callables or None:
+    (make_pipeline | None, run_function | None, PipelineClass | None)
+    """
+    make_pipeline = None
+    run_fn = None
+    PipelineCls = None
+    try:
+        from nl2sql.pipeline import make_pipeline as _mk  # type: ignore
+
+        make_pipeline = _mk
+    except Exception:
+        pass
+    try:
+        from nl2sql.pipeline import run_nl2sql as _run  # type: ignore
+
+        run_fn = _run
+    except Exception:
+        pass
+    try:
+        from nl2sql.pipeline import Pipeline as _P  # type: ignore
+
+        PipelineCls = _P
+    except Exception:
+        pass
+    return make_pipeline, run_fn, PipelineCls
+
 
 LOG_DIR = Path("logs/spider_eval")
 LOG_DIR.mkdir(parents=True, exist_ok=True)
 
+FORBIDDEN_NODES: Tuple[type, ...] = (
+    exp.Insert,
+    exp.Delete,
+    exp.Update,
+    exp.Drop,
+    exp.Alter,
+    exp.Attach,
+    exp.Pragma,
+    exp.Create,
+)
+
 
 def normalize_sql(sql: str) -> str:
-    # simple version; you could strengthen it with parse + rebuild
     return " ".join(sql.lower().strip().split())
 
 
-def compare_results(pred_rows, gold_rows):
+def compare_results(
+    pred_rows: Optional[Iterable[Any]], gold_rows: Optional[Iterable[Any]]
+) -> bool:
     if pred_rows is None or gold_rows is None:
         return False
-    # if order does not matter
     return set(pred_rows) == set(gold_rows)
 
 
-def try_execute_sql(sql_db, sql, timeout: float = None):
+def try_execute_sql(
+    sql_db: SQLDatabase,
+    sql: str,
+    timeout: Optional[float] = None,  # kept for API compatibility
+) -> tuple[Optional[list[tuple[Any, ...]]], float, Optional[str]]:
     start = time.time()
     try:
-        rows = sql_db.run(sql)
+        raw_rows = sql_db.run(sql)
+
+        # Normalize result shape for MyPy and downstream code
+        if isinstance(raw_rows, list):
+            rows = [tuple(r) for r in raw_rows]
+        elif isinstance(raw_rows, tuple):
+            rows = [tuple(raw_rows)]
+        else:
+            # Fallback cast - if library returns ResultSet or something similar
+            rows = cast(list[tuple[Any, ...]], raw_rows)
+
         return rows, time.time() - start, None
+
     except Exception as e:
         return None, time.time() - start, str(e)
@@ -44,7 +102,7 @@ def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
     except Exception:
         return False
 
-    def normalize_ast(node: exp.Expression):
+    def normalize_ast(node: exp.Expression) -> exp.Expression:
         for name, arg in node.args.items():
             if isinstance(arg, list):
                 arg.sort(key=lambda x: str(x))
@@ -73,19 +131,7 @@ def get_git_commit_hash() -> str:
     return "UNKNOWN"
 
 
-FORBIDDEN_NODES = (
-    exp.Insert,
-    exp.Delete,
-    exp.Update,
-    exp.Drop,
-    exp.Alter,
-    exp.Attach,
-    exp.Pragma,
-    exp.Create,
-)
-
-
-def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
+def is_safe_sql(sql: str, dialect: Optional[str] = None) -> bool:
     try:
         ast = parse_one(sql, read=dialect)
     except ParseError:
@@ -98,7 +144,104 @@ def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
     return True
 
 
-def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
+# --- replacement for get_schema_preview from app.routers ---
+def get_schema_preview_sqlalchemy(db_path: str, max_cols: int = 0) -> str:
+    """
+    Lightweight schema preview using SQLAlchemy inspector.
+    max_cols=0 => unlimited
+    """
+    engine = create_engine(f"sqlite:///{db_path}")
+    insp = inspect(engine)
+    lines: list[str] = []
+    for tbl in sorted(insp.get_table_names()):
+        cols = insp.get_columns(tbl)
+        if max_cols > 0:
+            cols = cols[:max_cols]
+        col_str = ", ".join(f"{c['name']}:{c.get('type')}" for c in cols)
+        pks = insp.get_pk_constraint(tbl).get("constrained_columns") or []
+        pk_str = f" | PK: {', '.join(pks)}" if pks else ""
+        fks = insp.get_foreign_keys(tbl)
+        fk_str = ""
+        if fks:
+            fks_desc = []
+            for fk in fks:
+                ref = fk.get("referred_table")
+                cols_fk = ", ".join(fk.get("constrained_columns") or [])
+                ref_cols = ", ".join(fk.get("referred_columns") or [])
+                fks_desc.append(f"{cols_fk} -> {ref}({ref_cols})")
+            fk_str = " | FK: " + " ; ".join(fks_desc)
+        lines.append(f"{tbl}({col_str}){pk_str}{fk_str}")
+    engine.dispose()
+    return "\n".join(lines)
+
+
+def _generate_sql(
+    question: str, sql_db: SQLDatabase, schema_text: str, max_output_tokens: int = 1000
+) -> tuple[str, str, dict[str, Any]]:
+    """
+    Returns: (status_msg, sql_text, extra_output)
+    Strategy:
+      1) If nl2sql.pipeline.run_nl2sql exists: call it.
+      2) Else if nl2sql.pipeline.make_pipeline exists: build and run.
+      3) Else if nl2sql.pipeline.Pipeline exists: instantiate minimal pipeline and run.
+      4) Else: raise NotImplementedError.
+    """
+    make_pipeline, run_fn, PipelineCls = _try_import_pipeline()
+
+    # Case 1: direct run function
+    if run_fn is not None:
+        res = run_fn(
+            question=question,
+            schema_text=schema_text,
+            sql_db=sql_db,
+            max_output_tokens=max_output_tokens,
+        )
+        # Expecting a dict-like or object with attributes; normalize:
+        if isinstance(res, dict):
+            msg = res.get("status", "ok")
+            sql = res.get("sql", "")
+            return msg, sql, res
+        # fallback generic
+        msg = getattr(res, "status", "ok")
+        sql = getattr(res, "sql", "")
+        return msg, sql, {"result": res}
+
+    # Case 2: factory + run
+    if make_pipeline is not None:
+        pipe = make_pipeline(sql_db=sql_db, schema_text=schema_text)  # type: ignore[arg-type]
+        # Common conventions:
+        if hasattr(pipe, "run"):
+            out = pipe.run(question)  # type: ignore[call-arg]
+        elif hasattr(pipe, "execute"):
+            out = pipe.execute(question)  # type: ignore[call-arg]
+        else:
+            raise RuntimeError("Pipeline object has no run/execute()")
+        msg = getattr(out, "status", "ok")
+        sql = getattr(out, "sql", "")
+        return msg, sql, {"result": out}
+
+    # Case 3: class-based pipeline
+    if PipelineCls is not None:
+        # Try minimal constructor names; adjust to your class signature if needed.
+        # We pass what we have; extra kwargs should be ignored or have defaults.
+        pipe = PipelineCls(sql_db=sql_db, schema_text=schema_text)
+        if hasattr(pipe, "run"):
+            out = pipe.run(question)  # type: ignore[call-arg]
+        else:
+            raise RuntimeError("Pipeline class has no run()")
+        msg = getattr(out, "status", "ok")
+        sql = getattr(out, "sql", "")
+        return msg, sql, {"result": out}
+
+    raise NotImplementedError(
+        "Cannot locate a public NL2SQL entrypoint in nl2sql.pipeline. "
+        "Expose one of: run_nl2sql(), make_pipeline(), or Pipeline.run()."
+    )
+
+
+def run_eval(
+    split: str = "dev", limit: int = 100, resume: bool = True, sleep_time: float = 0.01
+) -> None:
     data = load_spider_sqlite(split)
     if len(data) < limit:
         limit = len(data)
@@ -113,7 +256,7 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
     results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
     metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
 
-    done = set()
+    done: set[tuple[str, str]] = set()
     if resume and results_fn.exists():
         with results_fn.open("r", encoding="utf-8") as f:
             for line in f:
@@ -126,6 +269,8 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
                     pass
 
     write_header = not results_fn.exists()
+    agg: list[dict[str, Any]] = []
+
     with (
         results_fn.open("a", encoding="utf-8") as fout,
         pred_txt.open("a", encoding="utf-8") as fpred,
@@ -141,25 +286,48 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
             fout.write("# " + json.dumps(header, ensure_ascii=False) + "\n")
             fout.flush()
 
-        agg = []
         for ex in tqdm(data):
             key = (ex.db_id, ex.question)
             if resume and key in done:
                 continue
 
             db_path = str(ex.db_path)
-            schema = get_schema_preview(db_path, 0)
+            schema = get_schema_preview_sqlalchemy(db_path, max_cols=0)
             sql_db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
-            chain = make_sql_chain(sql_db)
-            state = {
-                "db_path": db_path,
-                "sql_db": sql_db,
-                "schema_text": schema,
-                "chain": chain,
-            }
 
             t0 = time.time()
-            msg, sql, output = on_generate_query(ex.question, 1000, state)
+            try:
+                msg, sql, output = _generate_sql(
+                    ex.question, sql_db, schema, max_output_tokens=1000
+                )
+            except NotImplementedError as e:
+                rec = {
+                    "db_id": ex.db_id,
+                    "question": ex.question,
+                    "gold_sql": ex.gold_sql,
+                    "pred_sql": "",
+                    "status": "no_entrypoint",
+                    "output": {"error": str(e)},
+                    "gen_time": time.time() - t0,
+                    "exec_time": None,
+                    "error": "no_entrypoint",
+                    "gold_error": None,
+                    "pred_rows": None,
+                    "gold_rows": None,
+                    "exact_match": False,
+                    "exact_match_structural": False,
+                    "execution_accuracy": False,
+                    "safe_check_failed": True,
+                }
+                fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
+                fout.flush()
+                fgold.write(f"{ex.gold_sql}\t{ex.db_id}\n")
+                fgold.flush()
+                agg.append(rec)
+                if sleep_time > 0:
+                    time.sleep(sleep_time)
+                continue
+
             gen_time = time.time() - t0
 
             safe_flag = is_safe_sql(sql)
@@ -197,21 +365,9 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
             gold_rows, gold_time, gold_error = try_execute_sql(sql_db, ex.gold_sql)
 
             skip = gold_error is not None
-
-            em = False
-            if not skip:
-                try:
-                    em = normalize_sql(sql) == normalize_sql(ex.gold_sql)
-                except Exception:
-                    pass
-
-            em_struct = False
-            if not skip:
-                em_struct = exact_match_structural(sql, ex.gold_sql)
-
-            exec_acc = False
-            if not skip:
-                exec_acc = compare_results(pred_rows, gold_rows)
+            em = normalize_sql(sql) == normalize_sql(ex.gold_sql) if not skip else False
+            em_struct = exact_match_structural(sql, ex.gold_sql) if not skip else False
+            exec_acc = compare_results(pred_rows, gold_rows) if not skip else False
 
             rec = {
                 "db_id": ex.db_id,
@@ -231,7 +387,6 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
                 "execution_accuracy": exec_acc,
                 "safe_check_failed": False,
             }
-
             fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
             fout.flush()
             fpred.write(f"{sql}\t{ex.db_id}\n")
@@ -246,7 +401,7 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
     valid = [
         r
        for r in agg
-        if (not r.get("safe_check_failed", False)) and r.get("gold_error") is None
+        if (not r.get("safe_check_failed", False)) and (r.get("gold_error") is None)
     ]
     total_valid = len(valid)
     total_all = len(agg)
@@ -263,8 +418,8 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
         if (r.get("error") is not None) and (not r.get("safe_check_failed", False))
     )
     safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
-    avg_gen_time = sum(r["gen_time"] for r in valid) / total_valid
-    avg_exec_time = sum(r["exec_time"] for r in valid) / total_valid
+    avg_gen_time = sum(float(r["gen_time"]) for r in valid) / total_valid
+    avg_exec_time = sum(float(r["exec_time"]) for r in valid) / total_valid
 
     metrics = {
         "commit_hash": commit_hash,
@@ -282,6 +437,7 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
         "run_id": start_ts,
     }
 
+    metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
    with metrics_fn.open("w", encoding="utf-8") as fm:
         json.dump(metrics, fm, ensure_ascii=False, indent=2)
 
 
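And a quick usage sketch for `get_schema_preview_sqlalchemy`, showing the one-line-per-table format it emits (the table and columns below are invented for the demo, and the function is assumed to be in scope):

import sqlite3
import tempfile
from pathlib import Path

# Throwaway SQLite file with a single table to preview.
db_file = Path(tempfile.mkdtemp()) / "demo.sqlite"
conn = sqlite3.connect(str(db_file))
conn.execute("CREATE TABLE singer (singer_id INTEGER PRIMARY KEY, name TEXT)")
conn.commit()
conn.close()

print(get_schema_preview_sqlalchemy(str(db_file)))
# -> singer(singer_id:INTEGER, name:TEXT) | PK: singer_id

Keeping the preview compact matters because this string becomes the `schema_text` passed to `_generate_sql` for every Spider example.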
benchmarks/run.py CHANGED
@@ -1,11 +1,11 @@
-# benchmarks/run.py
 from __future__ import annotations
+
 import argparse
 import os
 import json
 import time
 from pathlib import Path
-from typing import Iterable, List, Dict, Any
+from typing import Iterable, List, Dict, Any, Protocol, Tuple, Optional
 
 # ---- app imports
 from nl2sql.pipeline import Pipeline, FinalResult
@@ -22,11 +22,34 @@ from adapters.db.sqlite_adapter import SQLiteAdapter
 from adapters.llm.openai_provider import OpenAIProvider
 
 
-# ---- fallbacks: Dummy LLM (so it runs without API keys)
+# ---- LLM protocol (unifies OpenAIProvider and DummyLLM for mypy)
+class LLMProvider(Protocol):
+    """Minimal interface required by Planner/Generator/Repair stages."""
+
+    provider_id: str
+
+    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
+        ...
+
+    def generate_sql(
+        self,
+        *,
+        user_query: str,
+        schema_preview: str,
+        plan_text: str,
+        clarify_answers: Optional[Any] = None,
+    ) -> Tuple[str, str, int, int, float]:
+        ...
+
+    def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
+        ...
+
+
+# ---- fallback: Dummy LLM (so it runs without API keys)
 class DummyLLM:
     provider_id = "dummy-llm"
 
-    def plan(self, *, user_query: str, schema_preview: str):
+    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
         text = (
             f"- understand question: {user_query}\n"
             "- identify tables\n- join if needed\n- filter\n- order/limit"
@@ -39,14 +62,14 @@ class DummyLLM:
         user_query: str,
         schema_preview: str,
         plan_text: str,
-        clarify_answers=None,
-    ):
+        clarify_answers: Optional[Any] = None,
+    ) -> Tuple[str, str, int, int, float]:
         # naive demo SQL (so pipeline flows end-to-end)
         sql = "SELECT 1 AS one;"
         rationale = "Demo SQL from DummyLLM"
         return sql, rationale, 0, 0, 0.0
 
-    def repair(self, *, sql: str, error_msg: str, schema_preview: str):
+    def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
         return sql, 0, 0, 0.0
 
 
@@ -73,11 +96,12 @@ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
     db = SQLiteAdapter(str(db_path))
     executor = Executor(db)
 
-    # LLM provider
+    # LLM provider (typed to the Protocol so mypy accepts either provider)
+    llm: LLMProvider
     if use_openai and os.getenv("OPENAI_API_KEY"):
-        llm = OpenAIProvider()
+        llm = OpenAIProvider()  # conforms to LLMProvider
     else:
-        llm = DummyLLM()
+        llm = DummyLLM()  # conforms to LLMProvider
 
     # stages
     detector = AmbiguityDetector()
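The `LLMProvider` Protocol is what lets mypy accept `OpenAIProvider` and `DummyLLM` for the same `llm` variable without a shared base class: conformance is structural, checked against attribute and method signatures rather than inheritance. A minimal sketch of the mechanism with illustrative names (nothing below comes from the repo):

from typing import Protocol, Tuple


class Planner(Protocol):
    """Anything with a matching plan() signature satisfies this type."""

    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
        ...


class EchoLLM:  # note: no inheritance from Planner
    def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
        return (f"plan for: {user_query}", 0, 0, 0.0)


llm: Planner = EchoLLM()  # mypy verifies this assignment structurally
print(llm.plan(user_query="top singers by concert count", schema_preview="singer(id, name)"))

This is also why the diff annotates `llm: LLMProvider` before the if/else: without the annotation, mypy infers the variable's type from the first branch and then rejects the assignment in the second.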