Melika Kheirieh committed on
Commit
a337fad
·
1 Parent(s): 1100ebf

build(mypy): fix type errors and add safety guards for None values

Browse files
.github/workflows/ci.yml CHANGED
@@ -49,7 +49,7 @@ jobs:
49
  run: python -m ruff check .
50
 
51
  - name: Type check (mypy)
52
- run: python -m mypy .
53
 
54
  - name: Run tests
55
  run: python -m pytest -q
 
49
  run: python -m ruff check .
50
 
51
  - name: Type check (mypy)
52
+ run: python -m mypy . --ignore-missing-imports --explicit-package-bases
53
 
54
  - name: Run tests
55
  run: python -m pytest -q
adapters/db/postgres_adapter.py CHANGED
@@ -15,34 +15,56 @@ class PostgresAdapter(DBAdapter):
15
  self.dsn = dsn
16
 
17
  def preview_schema(self, limit_per_table: int = 0) -> str:
 
 
 
 
 
18
  with psycopg.connect(self.dsn) as conn:
19
- cur = conn.cursor()
20
- cur.execute("""
21
- SELECT table_name
22
- FROM information_schema.tables
23
- WHERE table_schema = 'public';
24
- """)
25
- tables = [t[0] for t in cur.fetchall()]
26
- lines = []
27
- for t in tables:
28
  cur.execute(
29
  """
30
- SELECT column_name, data_type
31
- FROM information_schema.columns
32
- WHERE table_name = %s;
33
- """,
34
- (t,),
35
  )
36
- cols = [f"{c[0]}:{c[1]}" for c in cur.fetchall()]
37
- lines.append(f"- {t} ({', '.join(cols)})")
38
- return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
41
- if not sql.strip().lower().startswith("select"):
 
 
 
42
  raise ValueError("Only SELECT statements are allowed.")
 
43
  with psycopg.connect(self.dsn) as conn:
44
- cur = conn.cursor()
45
- cur.execute(sql)
46
- rows = cur.fetchall()
47
- cols = [desc[0] for desc in cur.description]
48
- return rows, cols
 
 
15
  self.dsn = dsn
16
 
17
  def preview_schema(self, limit_per_table: int = 0) -> str:
18
+ """
19
+ Return a simple textual preview of tables and their columns in public schema.
20
+ Example line: "- users (id:integer, name:text)"
21
+ """
22
+ lines: List[str] = []
23
  with psycopg.connect(self.dsn) as conn:
24
+ with conn.cursor() as cur:
25
+ # list tables
 
 
 
 
 
 
 
26
  cur.execute(
27
  """
28
+ SELECT table_name
29
+ FROM information_schema.tables
30
+ WHERE table_schema = 'public'
31
+ ORDER BY table_name;
32
+ """
33
  )
34
+ table_rows = cur.fetchall() or []
35
+ tables: List[str] = [t[0] for t in table_rows if t and t[0]]
36
+
37
+ for t in tables:
38
+ # list columns for table t
39
+ cur.execute(
40
+ """
41
+ SELECT column_name, data_type
42
+ FROM information_schema.columns
43
+ WHERE table_schema = 'public' AND table_name = %s
44
+ ORDER BY ordinal_position;
45
+ """,
46
+ (t,),
47
+ )
48
+ col_rows = cur.fetchall() or []
49
+ # guard against None; build "name:type"
50
+ cols: List[str] = [
51
+ f"{c[0]}:{c[1]}" for c in col_rows if c and len(c) >= 2
52
+ ]
53
+ lines.append(f"- {t} ({', '.join(cols)})")
54
+
55
+ return "\n".join(lines)
56
 
57
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
58
+ """
59
+ Execute a read-only SELECT query and return (rows, columns).
60
+ """
61
+ if not sql or not sql.strip().lower().startswith("select"):
62
  raise ValueError("Only SELECT statements are allowed.")
63
+
64
  with psycopg.connect(self.dsn) as conn:
65
+ with conn.cursor() as cur:
66
+ cur.execute(sql)
67
+ rows = cur.fetchall() or []
68
+ desc = cur.description or ()
69
+ cols: List[str] = [d[0] for d in desc if d]
70
+ return rows, cols
app/routers/nl2sql.py CHANGED
@@ -13,11 +13,13 @@ from nl2sql.repair import Repair
13
  from adapters.db.sqlite_adapter import SQLiteAdapter
14
  from adapters.db.postgres_adapter import PostgresAdapter
15
  import os
 
16
 
17
 
18
  router = APIRouter(prefix="/nl2sql")
19
 
20
 
 
21
  if os.getenv("DB_MODE", "sqlite") == "postgres":
22
  _db = PostgresAdapter(os.environ["POSTGRES_DSN"])
23
  else:
 
13
  from adapters.db.sqlite_adapter import SQLiteAdapter
14
  from adapters.db.postgres_adapter import PostgresAdapter
15
  import os
16
+ from typing import Union
17
 
18
 
19
  router = APIRouter(prefix="/nl2sql")
20
 
21
 
22
+ _db: Union[PostgresAdapter, SQLiteAdapter]
23
  if os.getenv("DB_MODE", "sqlite") == "postgres":
24
  _db = PostgresAdapter(os.environ["POSTGRES_DSN"])
25
  else:
app/schemas.py CHANGED
@@ -1,5 +1,5 @@
1
  from pydantic import BaseModel
2
- from typing import List, Optional, Any, Dict
3
 
4
 
5
  class NL2SQLRequest(BaseModel):
@@ -19,9 +19,9 @@ class TraceModel(BaseModel):
19
 
20
  class NL2SQLResponse(BaseModel):
21
  ambiguous: bool = False
22
- sql: str
23
  rationale: Optional[str] = None
24
- traces: List[TraceModel] = []
25
 
26
 
27
  class ClarifyResponse(BaseModel):
 
1
  from pydantic import BaseModel
2
+ from typing import List, Optional, Any, Dict, Union
3
 
4
 
5
  class NL2SQLRequest(BaseModel):
 
19
 
20
  class NL2SQLResponse(BaseModel):
21
  ambiguous: bool = False
22
+ sql: Optional[str] = None
23
  rationale: Optional[str] = None
24
+ traces: List[Union[TraceModel, dict]] = []
25
 
26
 
27
  class ClarifyResponse(BaseModel):
benchmarks/run.py CHANGED
@@ -5,9 +5,10 @@ import os
5
  import json
6
  import time
7
  from pathlib import Path
 
8
 
9
  # ---- app imports
10
- from nl2sql.pipeline import Pipeline
11
  from nl2sql.ambiguity_detector import AmbiguityDetector
12
  from nl2sql.planner import Planner
13
  from nl2sql.generator import Generator
@@ -26,7 +27,10 @@ class DummyLLM:
26
  provider_id = "dummy-llm"
27
 
28
  def plan(self, *, user_query: str, schema_preview: str):
29
- text = f"- understand question: {user_query}\n- identify tables\n- join if needed\n- filter\n- order/limit"
 
 
 
30
  return text, 0, 0, 0.0
31
 
32
  def generate_sql(
@@ -68,11 +72,13 @@ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
68
  # DB adapter
69
  db = SQLiteAdapter(str(db_path))
70
  executor = Executor(db)
 
71
  # LLM provider
72
  if use_openai and os.getenv("OPENAI_API_KEY"):
73
  llm = OpenAIProvider()
74
  else:
75
  llm = DummyLLM()
 
76
  # stages
77
  detector = AmbiguityDetector()
78
  planner = Planner(llm)
@@ -80,6 +86,7 @@ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
80
  safety = Safety()
81
  verifier = Verifier()
82
  repair = Repair(llm)
 
83
  # pipeline
84
  return Pipeline(
85
  detector=detector,
@@ -92,33 +99,49 @@ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
92
  )
93
 
94
 
95
- def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
96
- results = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  for q in queries:
98
  t0 = time.perf_counter()
99
- r = pipeline.run(user_query=q, schema_preview=schema_preview)
100
- latency_ms = (time.perf_counter() - t0) * 1000
101
- ok = (not r.get("ambiguous")) and ("error" not in r)
102
-
103
- traces = r.get("traces", [])
104
- cost_sum = 0.0
105
- for t in traces:
106
- try:
107
- cost_sum += float(t.get("cost_usd", 0.0))
108
- except Exception:
109
- pass
110
 
111
  results.append(
112
  {
113
  "query": q,
114
  "exec_acc": 1.0 if ok else 0.0,
115
- "safe_fail": 0.0 if ok else 1.0 if "unsafe" in str(r).lower() else 0.0,
116
  "latency_ms": latency_ms,
117
  "cost_usd": cost_sum,
118
  "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
119
- "provider": pipeline.generator.llm.provider_id
120
- if hasattr(pipeline.generator, "llm")
121
- else "unknown",
122
  }
123
  )
124
 
@@ -129,7 +152,7 @@ def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
129
  print(f"[OK] wrote {len(results)} rows → {outfile}")
130
 
131
 
132
- def main():
133
  parser = argparse.ArgumentParser()
134
  parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
135
  parser.add_argument("--db", default="data/bench_demo.db")
@@ -140,9 +163,9 @@ def main():
140
  )
141
  args = parser.parse_args()
142
 
143
- ROOT = Path(__file__).resolve().parents[1] # project root
144
- outfile = (ROOT / args.outfile).resolve()
145
- db_path = (ROOT / args.db).resolve()
146
 
147
  ensure_demo_db(db_path)
148
  pipe = build_pipeline(db_path, use_openai=args.use_openai)
 
5
  import json
6
  import time
7
  from pathlib import Path
8
+ from typing import Iterable, List, Dict, Any
9
 
10
  # ---- app imports
11
+ from nl2sql.pipeline import Pipeline, FinalResult
12
  from nl2sql.ambiguity_detector import AmbiguityDetector
13
  from nl2sql.planner import Planner
14
  from nl2sql.generator import Generator
 
27
  provider_id = "dummy-llm"
28
 
29
  def plan(self, *, user_query: str, schema_preview: str):
30
+ text = (
31
+ f"- understand question: {user_query}\n"
32
+ "- identify tables\n- join if needed\n- filter\n- order/limit"
33
+ )
34
  return text, 0, 0, 0.0
35
 
36
  def generate_sql(
 
72
  # DB adapter
73
  db = SQLiteAdapter(str(db_path))
74
  executor = Executor(db)
75
+
76
  # LLM provider
77
  if use_openai and os.getenv("OPENAI_API_KEY"):
78
  llm = OpenAIProvider()
79
  else:
80
  llm = DummyLLM()
81
+
82
  # stages
83
  detector = AmbiguityDetector()
84
  planner = Planner(llm)
 
86
  safety = Safety()
87
  verifier = Verifier()
88
  repair = Repair(llm)
89
+
90
  # pipeline
91
  return Pipeline(
92
  detector=detector,
 
99
  )
100
 
101
 
102
+ def _sum_cost(traces: Iterable[Dict[str, Any]]) -> float:
103
+ total = 0.0
104
+ for tr in traces:
105
+ try:
106
+ total += float(tr.get("cost_usd", 0.0))
107
+ except Exception:
108
+ # ignore bad values
109
+ pass
110
+ return total
111
+
112
+
113
+ def _is_safe_fail(ok: bool, details: List[str] | None) -> float:
114
+ """Return 1.0 when pipeline failed due to unsafe SQL (heuristic)."""
115
+ if ok:
116
+ return 0.0
117
+ txt = " ".join(details or []).lower()
118
+ return 1.0 if "unsafe" in txt else 0.0
119
+
120
+
121
+ def run_benchmark(
122
+ queries: List[str], schema_preview: str, pipeline: Pipeline, outfile: Path
123
+ ) -> None:
124
+ results: List[Dict[str, Any]] = []
125
  for q in queries:
126
  t0 = time.perf_counter()
127
+ res: FinalResult = pipeline.run(user_query=q, schema_preview=schema_preview)
128
+ latency_ms = (time.perf_counter() - t0) * 1000.0
129
+
130
+ ok = (not res.ambiguous) and (not res.error) and bool(res.ok)
131
+ traces = res.traces or []
132
+ cost_sum = _sum_cost(traces)
 
 
 
 
 
133
 
134
  results.append(
135
  {
136
  "query": q,
137
  "exec_acc": 1.0 if ok else 0.0,
138
+ "safe_fail": _is_safe_fail(ok, res.details),
139
  "latency_ms": latency_ms,
140
  "cost_usd": cost_sum,
141
  "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
142
+ "provider": getattr(
143
+ getattr(pipeline.generator, "llm", None), "provider_id", "unknown"
144
+ ),
145
  }
146
  )
147
 
 
152
  print(f"[OK] wrote {len(results)} rows → {outfile}")
153
 
154
 
155
+ def main() -> None:
156
  parser = argparse.ArgumentParser()
157
  parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
158
  parser.add_argument("--db", default="data/bench_demo.db")
 
163
  )
164
  args = parser.parse_args()
165
 
166
+ root = Path(__file__).resolve().parents[1] # project root
167
+ outfile = (root / args.outfile).resolve()
168
+ db_path = (root / args.db).resolve()
169
 
170
  ensure_demo_db(db_path)
171
  pipe = build_pipeline(db_path, use_openai=args.use_openai)
nl2sql/pipeline.py CHANGED
@@ -14,7 +14,6 @@ from nl2sql.repair import Repair
14
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
15
 
16
 
17
- # ---- NEW: FinalResult as domain-level, type-safe result ----
18
  @dataclass(frozen=True)
19
  class FinalResult:
20
  ok: bool
@@ -144,7 +143,7 @@ class Pipeline:
144
  self.generator.run,
145
  user_query=user_query,
146
  schema_preview=schema_preview,
147
- plan_text=r_plan.data.get("plan"),
148
  clarify_answers=clarify_answers or {},
149
  )
150
  traces.extend(self._trace_list(r_gen))
@@ -160,11 +159,10 @@ class Pipeline:
160
  verified=None,
161
  traces=traces,
162
  )
163
- sql = r_gen.data.get("sql")
164
- rationale = r_gen.data.get("rationale")
165
 
166
  # --- 4) safety
167
- # fix: align with DummySafety signature → use .run (not .check)
168
  r_safe = self._safe_stage(self.safety.run, sql=sql)
169
  traces.extend(self._trace_list(r_safe))
170
  if not r_safe.ok:
@@ -181,13 +179,17 @@ class Pipeline:
181
  )
182
 
183
  # --- 5) executor
184
- r_exec = self._safe_stage(self.executor.run, sql=r_safe.data.get("sql", sql))
 
 
185
  traces.extend(self._trace_list(r_exec))
186
  if not r_exec.ok:
187
  details.extend(r_exec.error or [])
188
 
189
  # --- 6) verifier
190
- r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec.data)
 
 
191
  traces.extend(self._trace_list(r_ver))
192
  verified = bool(r_ver.ok)
193
 
@@ -203,7 +205,7 @@ class Pipeline:
203
  traces.extend(self._trace_list(r_fix))
204
  if not r_fix.ok:
205
  break
206
- sql = r_fix.data.get("sql")
207
 
208
  r_safe = self._safe_stage(self.safety.run, sql=sql)
209
  traces.extend(self._trace_list(r_safe))
@@ -212,7 +214,7 @@ class Pipeline:
212
  continue
213
 
214
  r_exec = self._safe_stage(
215
- self.executor.run, sql=r_safe.data.get("sql", sql)
216
  )
217
  traces.extend(self._trace_list(r_exec))
218
  if not r_exec.ok:
@@ -220,7 +222,7 @@ class Pipeline:
220
  continue
221
 
222
  r_ver = self._safe_stage(
223
- self.verifier.run, sql=sql, exec_result=r_exec.data
224
  )
225
  traces.extend(self._trace_list(r_ver))
226
  verified = bool(r_ver.ok)
 
14
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
15
 
16
 
 
17
  @dataclass(frozen=True)
18
  class FinalResult:
19
  ok: bool
 
143
  self.generator.run,
144
  user_query=user_query,
145
  schema_preview=schema_preview,
146
+ plan_text=(r_plan.data or {}).get("plan"),
147
  clarify_answers=clarify_answers or {},
148
  )
149
  traces.extend(self._trace_list(r_gen))
 
159
  verified=None,
160
  traces=traces,
161
  )
162
+ sql = (r_gen.data or {}).get("sql")
163
+ rationale = (r_gen.data or {}).get("rationale")
164
 
165
  # --- 4) safety
 
166
  r_safe = self._safe_stage(self.safety.run, sql=sql)
167
  traces.extend(self._trace_list(r_safe))
168
  if not r_safe.ok:
 
179
  )
180
 
181
  # --- 5) executor
182
+ r_exec = self._safe_stage(
183
+ self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
184
+ )
185
  traces.extend(self._trace_list(r_exec))
186
  if not r_exec.ok:
187
  details.extend(r_exec.error or [])
188
 
189
  # --- 6) verifier
190
+ r_ver = self._safe_stage(
191
+ self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
192
+ )
193
  traces.extend(self._trace_list(r_ver))
194
  verified = bool(r_ver.ok)
195
 
 
205
  traces.extend(self._trace_list(r_fix))
206
  if not r_fix.ok:
207
  break
208
+ sql = (r_fix.data or {}).get("sql")
209
 
210
  r_safe = self._safe_stage(self.safety.run, sql=sql)
211
  traces.extend(self._trace_list(r_safe))
 
214
  continue
215
 
216
  r_exec = self._safe_stage(
217
+ self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
218
  )
219
  traces.extend(self._trace_list(r_exec))
220
  if not r_exec.ok:
 
222
  continue
223
 
224
  r_ver = self._safe_stage(
225
+ self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
226
  )
227
  traces.extend(self._trace_list(r_ver))
228
  verified = bool(r_ver.ok)
nl2sql/safety.py CHANGED
@@ -86,3 +86,5 @@ class Safety:
86
  stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
87
  ),
88
  )
 
 
 
86
  stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
87
  ),
88
  )
89
+
90
+ run = check