Melika Kheirieh committed on
Commit
c1bc4eb
·
1 Parent(s): 646d80b

style: format code with ruff

Browse files
adapters/db/base.py CHANGED
@@ -1,8 +1,10 @@
1
  from typing import Tuple, List, Dict, Any, Protocol
2
  from typing import List, Tuple, Any
3
 
 
4
  class DBAdapter(Protocol):
5
  """Abstract database adapter for read-only queries."""
 
6
  name: str
7
  dialect: str
8
 
@@ -10,4 +12,4 @@ class DBAdapter(Protocol):
10
  """Generate a readable summary of the database schema with optional sample rows per table."""
11
 
12
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
13
- """Execute a SELECT query and return (rows, columns)."""
 
1
  from typing import Tuple, List, Dict, Any, Protocol
2
  from typing import List, Tuple, Any
3
 
4
+
5
  class DBAdapter(Protocol):
6
  """Abstract database adapter for read-only queries."""
7
+
8
  name: str
9
  dialect: str
10
 
 
12
  """Generate a readable summary of the database schema with optional sample rows per table."""
13
 
14
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
15
+ """Execute a SELECT query and return (rows, columns)."""
adapters/db/postgres_adapter.py CHANGED
@@ -2,6 +2,7 @@ import psycopg
2
  from typing import Any, List, Tuple
3
  from adapters.db.base import DBAdapter
4
 
 
5
  class PostgresAdapter(DBAdapter):
6
  name = "postgres"
7
  dialect = "postgres"
@@ -24,11 +25,14 @@ class PostgresAdapter(DBAdapter):
24
  tables = [t[0] for t in cur.fetchall()]
25
  lines = []
26
  for t in tables:
27
- cur.execute(f"""
 
28
  SELECT column_name, data_type
29
  FROM information_schema.columns
30
  WHERE table_name = %s;
31
- """, (t,))
 
 
32
  cols = [f"{c[0]}:{c[1]}" for c in cur.fetchall()]
33
  lines.append(f"- {t} ({', '.join(cols)})")
34
  return "\n".join(lines)
 
2
  from typing import Any, List, Tuple
3
  from adapters.db.base import DBAdapter
4
 
5
+
6
  class PostgresAdapter(DBAdapter):
7
  name = "postgres"
8
  dialect = "postgres"
 
25
  tables = [t[0] for t in cur.fetchall()]
26
  lines = []
27
  for t in tables:
28
+ cur.execute(
29
+ f"""
30
  SELECT column_name, data_type
31
  FROM information_schema.columns
32
  WHERE table_name = %s;
33
+ """,
34
+ (t,),
35
+ )
36
  cols = [f"{c[0]}:{c[1]}" for c in cur.fetchall()]
37
  lines.append(f"- {t} ({', '.join(cols)})")
38
  return "\n".join(lines)
adapters/db/sqlite_adapter.py CHANGED
@@ -2,6 +2,7 @@ import sqlite3
2
  from typing import List, Tuple, Any
3
  from adapters.db.base import DBAdapter
4
 
 
5
  class SQLiteAdapter(DBAdapter):
6
  name = "sqlite"
7
  dialect = "sqlite"
 
2
  from typing import List, Tuple, Any
3
  from adapters.db.base import DBAdapter
4
 
5
+
6
  class SQLiteAdapter(DBAdapter):
7
  name = "sqlite"
8
  dialect = "sqlite"
adapters/llm/base.py CHANGED
@@ -2,15 +2,26 @@
2
  from __future__ import annotations
3
  from typing import Tuple, List, Dict, Any, Protocol
4
 
 
5
  class LLMProvider(Protocol):
6
  provider_id: str
7
 
8
- def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
 
 
9
  """Return (plan_text, token_in, token_out, cost_usd)."""
10
 
11
- def generate_sql(self, *, user_query: str, schema_preview: str, plan_text: str,
12
- clarify_answers: Dict[str, Any] | None = None) -> Tuple[str, str, int, int, float]:
 
 
 
 
 
 
13
  """Return (sql, rationale, token_in, token_out, cost_usd)."""
14
 
15
- def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
 
 
16
  """Return (patched_sql, token_in, token_out, cost_usd)."""
 
2
  from __future__ import annotations
3
  from typing import Tuple, List, Dict, Any, Protocol
4
 
5
+
6
  class LLMProvider(Protocol):
7
  provider_id: str
8
 
9
+ def plan(
10
+ self, *, user_query: str, schema_preview: str
11
+ ) -> Tuple[str, int, int, float]:
12
  """Return (plan_text, token_in, token_out, cost_usd)."""
13
 
14
+ def generate_sql(
15
+ self,
16
+ *,
17
+ user_query: str,
18
+ schema_preview: str,
19
+ plan_text: str,
20
+ clarify_answers: Dict[str, Any] | None = None,
21
+ ) -> Tuple[str, str, int, int, float]:
22
  """Return (sql, rationale, token_in, token_out, cost_usd)."""
23
 
24
+ def repair(
25
+ self, *, sql: str, error_msg: str, schema_preview: str
26
+ ) -> Tuple[str, int, int, float]:
27
  """Return (patched_sql, token_in, token_out, cost_usd)."""
adapters/llm/openai_provider.py CHANGED
@@ -11,14 +11,13 @@ from openai import OpenAI
11
  # - OPENAI_MODEL_ID (e.g., "gpt-4o-mini")
12
 
13
 
14
-
15
  class OpenAIProvider(LLMProvider):
16
  provider_id = "openai"
17
 
18
  def __init__(self) -> None:
19
  self.client = OpenAI(
20
  api_key=os.environ["OPENAI_API_KEY"],
21
- base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
22
  )
23
  self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")
24
 
@@ -27,16 +26,25 @@ class OpenAIProvider(LLMProvider):
27
  model=self.model,
28
  messages=[
29
  {"role": "system", "content": "You create SQL query plans."},
30
- {"role": "user", "content": f"Query: {user_query}\nSchema:\n{schema_preview}"}
 
 
 
31
  ],
32
- temperature=0
33
  )
34
  msg = completion.choices[0].message.content
35
  usage = completion.usage
36
- return msg, usage.prompt_tokens, usage.completion_tokens, self._estimate_cost(usage)
37
-
 
 
 
 
38
 
39
- def generate_sql(self, *, user_query, schema_preview, plan_text, clarify_answers=None):
 
 
40
  prompt = f"""
41
  You are a precise SQL generator.
42
  Return ONLY valid JSON with two keys: "sql" and "rationale".
@@ -60,9 +68,9 @@ class OpenAIProvider(LLMProvider):
60
  model=self.model,
61
  messages=[
62
  {"role": "system", "content": "You convert natural language to SQL."},
63
- {"role": "user", "content": prompt}
64
  ],
65
- temperature=0
66
  )
67
  content = completion.choices[0].message.content.strip()
68
  usage = completion.usage # ← لازم داریم
@@ -78,7 +86,7 @@ class OpenAIProvider(LLMProvider):
78
  end = content.rfind("}")
79
  if start != -1 and end != -1:
80
  try:
81
- parsed = json.loads(content[start:end + 1])
82
  except Exception:
83
  raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
84
  else:
@@ -93,19 +101,29 @@ class OpenAIProvider(LLMProvider):
93
  # IMPORTANT: return the expected 5-tuple
94
  return sql, rationale, t_in, t_out, cost
95
 
96
-
97
  def repair(self, *, sql, error_msg, schema_preview):
98
  completion = self.client.chat.completions.create(
99
  model=self.model,
100
  messages=[
101
- {"role": "system", "content": "You fix SQL queries keeping them SELECT-only."},
102
- {"role": "user", "content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}"}
 
 
 
 
 
 
103
  ],
104
- temperature=0
105
  )
106
  msg = completion.choices[0].message.content
107
  usage = completion.usage
108
- return msg, usage.prompt_tokens, usage.completion_tokens, self._estimate_cost(usage)
 
 
 
 
 
109
 
110
  def _estimate_cost(self, usage):
111
  # Rough estimation example — can be refined with official token pricing
 
11
  # - OPENAI_MODEL_ID (e.g., "gpt-4o-mini")
12
 
13
 
 
14
  class OpenAIProvider(LLMProvider):
15
  provider_id = "openai"
16
 
17
  def __init__(self) -> None:
18
  self.client = OpenAI(
19
  api_key=os.environ["OPENAI_API_KEY"],
20
+ base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
21
  )
22
  self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")
23
 
 
26
  model=self.model,
27
  messages=[
28
  {"role": "system", "content": "You create SQL query plans."},
29
+ {
30
+ "role": "user",
31
+ "content": f"Query: {user_query}\nSchema:\n{schema_preview}",
32
+ },
33
  ],
34
+ temperature=0,
35
  )
36
  msg = completion.choices[0].message.content
37
  usage = completion.usage
38
+ return (
39
+ msg,
40
+ usage.prompt_tokens,
41
+ usage.completion_tokens,
42
+ self._estimate_cost(usage),
43
+ )
44
 
45
+ def generate_sql(
46
+ self, *, user_query, schema_preview, plan_text, clarify_answers=None
47
+ ):
48
  prompt = f"""
49
  You are a precise SQL generator.
50
  Return ONLY valid JSON with two keys: "sql" and "rationale".
 
68
  model=self.model,
69
  messages=[
70
  {"role": "system", "content": "You convert natural language to SQL."},
71
+ {"role": "user", "content": prompt},
72
  ],
73
+ temperature=0,
74
  )
75
  content = completion.choices[0].message.content.strip()
76
  usage = completion.usage # ← لازم داریم
 
86
  end = content.rfind("}")
87
  if start != -1 and end != -1:
88
  try:
89
+ parsed = json.loads(content[start : end + 1])
90
  except Exception:
91
  raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
92
  else:
 
101
  # IMPORTANT: return the expected 5-tuple
102
  return sql, rationale, t_in, t_out, cost
103
 
 
104
  def repair(self, *, sql, error_msg, schema_preview):
105
  completion = self.client.chat.completions.create(
106
  model=self.model,
107
  messages=[
108
+ {
109
+ "role": "system",
110
+ "content": "You fix SQL queries keeping them SELECT-only.",
111
+ },
112
+ {
113
+ "role": "user",
114
+ "content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
115
+ },
116
  ],
117
+ temperature=0,
118
  )
119
  msg = completion.choices[0].message.content
120
  usage = completion.usage
121
+ return (
122
+ msg,
123
+ usage.prompt_tokens,
124
+ usage.completion_tokens,
125
+ self._estimate_cost(usage),
126
+ )
127
 
128
  def _estimate_cost(self, usage):
129
  # Rough estimation example — can be refined with official token pricing
app/main.py CHANGED
@@ -1,29 +1,29 @@
1
  from dotenv import load_dotenv
 
2
  load_dotenv()
3
 
4
  from fastapi import FastAPI
5
  from app.routers import nl2sql
 
6
  app = FastAPI(
7
  title="NL2SQL Copilot Prototype",
8
  version="0.1.0",
9
- description="Natural Language -> SQL Copilot API"
10
  )
11
 
12
  app.include_router(nl2sql.router, prefix="/api/v1")
13
 
 
14
  @app.get("/healthz")
15
  def health_check():
16
  return {"status": "ok"}
17
 
 
18
  @app.get("/")
19
  def root():
20
  return {"status": "ok", "message": "NL2SQL Copilot API is running"}
21
 
 
22
  @app.get("/health")
23
  def health():
24
- return {
25
- "status": "ok",
26
- "db": "connected",
27
- "llm": "reachable",
28
- "uptime_sec": 123.4
29
- }
 
1
  from dotenv import load_dotenv
2
+
3
  load_dotenv()
4
 
5
  from fastapi import FastAPI
6
  from app.routers import nl2sql
7
+
8
  app = FastAPI(
9
  title="NL2SQL Copilot Prototype",
10
  version="0.1.0",
11
+ description="Natural Language -> SQL Copilot API",
12
  )
13
 
14
  app.include_router(nl2sql.router, prefix="/api/v1")
15
 
16
+
17
  @app.get("/healthz")
18
  def health_check():
19
  return {"status": "ok"}
20
 
21
+
22
  @app.get("/")
23
  def root():
24
  return {"status": "ok", "message": "NL2SQL Copilot API is running"}
25
 
26
+
27
  @app.get("/health")
28
  def health():
29
+ return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
 
 
 
 
 
app/routers/nl2sql.py CHANGED
@@ -19,7 +19,6 @@ import os
19
  router = APIRouter(prefix="/nl2sql")
20
 
21
 
22
-
23
  if os.getenv("DB_MODE", "sqlite") == "postgres":
24
  _db = PostgresAdapter(os.environ["POSTGRES_DSN"])
25
  else:
@@ -40,7 +39,7 @@ _pipeline = Pipeline(
40
  safety=Safety(),
41
  executor=_executor,
42
  verifier=_verifier,
43
- repair=_repair
44
  )
45
 
46
 
@@ -48,6 +47,7 @@ def _to_dict(obj):
48
  """Helper: safely convert dataclass → dict."""
49
  return asdict(obj) if is_dataclass(obj) else obj
50
 
 
51
  def _round_trace(t: dict) -> dict:
52
  if t.get("cost_usd") is not None:
53
  t["cost_usd"] = round(t["cost_usd"], 6)
@@ -55,9 +55,12 @@ def _round_trace(t: dict) -> dict:
55
  t["duration_ms"] = round(t["duration_ms"], 2)
56
  return t
57
 
 
58
  @router.post("", name="nl2sql_handler")
59
  def nl2sql_handler(request: NL2SQLRequest):
60
- result = _pipeline.run(user_query=request.query, schema_preview=request.schema_preview)
 
 
61
 
62
  # --- Ensure result type ---
63
  if not isinstance(result, StageResult):
 
19
  router = APIRouter(prefix="/nl2sql")
20
 
21
 
 
22
  if os.getenv("DB_MODE", "sqlite") == "postgres":
23
  _db = PostgresAdapter(os.environ["POSTGRES_DSN"])
24
  else:
 
39
  safety=Safety(),
40
  executor=_executor,
41
  verifier=_verifier,
42
+ repair=_repair,
43
  )
44
 
45
 
 
47
  """Helper: safely convert dataclass → dict."""
48
  return asdict(obj) if is_dataclass(obj) else obj
49
 
50
+
51
  def _round_trace(t: dict) -> dict:
52
  if t.get("cost_usd") is not None:
53
  t["cost_usd"] = round(t["cost_usd"], 6)
 
55
  t["duration_ms"] = round(t["duration_ms"], 2)
56
  return t
57
 
58
+
59
  @router.post("", name="nl2sql_handler")
60
  def nl2sql_handler(request: NL2SQLRequest):
61
+ result = _pipeline.run(
62
+ user_query=request.query, schema_preview=request.schema_preview
63
+ )
64
 
65
  # --- Ensure result type ---
66
  if not isinstance(result, StageResult):
app/schemas.py CHANGED
@@ -1,11 +1,13 @@
1
  from pydantic import BaseModel
2
  from typing import List, Optional, Any, Dict
3
 
 
4
  class NL2SQLRequest(BaseModel):
5
  query: str
6
  schema_preview: str
7
  db_name: Optional[str] = "default"
8
 
 
9
  class TraceModel(BaseModel):
10
  stage: str
11
  duration_ms: float
@@ -14,16 +16,19 @@ class TraceModel(BaseModel):
14
  cost_usd: float | None = 0
15
  notes: Dict[str, Any] | None = None
16
 
 
17
  class NL2SQLResponse(BaseModel):
18
  ambiguous: bool = False
19
  sql: str
20
  rationale: Optional[str] = None
21
  traces: List[TraceModel] = []
22
 
 
23
  class ClarifyResponse(BaseModel):
24
  ambiguous: bool = True
25
  questions: List[str]
26
 
 
27
  class ErrorResponse(BaseModel):
28
  error: str
29
  details: List[str] | None = None
 
1
  from pydantic import BaseModel
2
  from typing import List, Optional, Any, Dict
3
 
4
+
5
  class NL2SQLRequest(BaseModel):
6
  query: str
7
  schema_preview: str
8
  db_name: Optional[str] = "default"
9
 
10
+
11
  class TraceModel(BaseModel):
12
  stage: str
13
  duration_ms: float
 
16
  cost_usd: float | None = 0
17
  notes: Dict[str, Any] | None = None
18
 
19
+
20
  class NL2SQLResponse(BaseModel):
21
  ambiguous: bool = False
22
  sql: str
23
  rationale: Optional[str] = None
24
  traces: List[TraceModel] = []
25
 
26
+
27
  class ClarifyResponse(BaseModel):
28
  ambiguous: bool = True
29
  questions: List[str]
30
 
31
+
32
  class ErrorResponse(BaseModel):
33
  error: str
34
  details: List[str] | None = None
benchmarks/evaluate_spider.py CHANGED
@@ -13,16 +13,19 @@ from sqlglot.errors import ParseError
13
  LOG_DIR = Path("logs/spider_eval")
14
  LOG_DIR.mkdir(parents=True, exist_ok=True)
15
 
 
16
  def normalize_sql(sql: str) -> str:
17
  # نسخه ساده؛ می‌تونی قوی‌ترش کنی با پارس + بازسازی
18
  return " ".join(sql.lower().strip().split())
19
 
 
20
  def compare_results(pred_rows, gold_rows):
21
  if pred_rows is None or gold_rows is None:
22
  return False
23
  # اگر ترتیب مهم نیست
24
  return set(pred_rows) == set(gold_rows)
25
 
 
26
  def try_execute_sql(sql_db, sql, timeout: float = None):
27
  start = time.time()
28
  try:
@@ -31,6 +34,7 @@ def try_execute_sql(sql_db, sql, timeout: float = None):
31
  except Exception as e:
32
  return None, time.time() - start, str(e)
33
 
 
34
  def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
35
  try:
36
  ast_pred = parse_one(sql_pred)
@@ -54,13 +58,19 @@ def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
54
  norm_gold = normalize_ast(ast_gold)
55
  return norm_prd == norm_gold
56
 
 
57
  def get_git_commit_hash() -> str:
58
  try:
59
- out = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("ascii")
 
 
 
 
60
  return out
61
  except Exception:
62
  return "UNKNOWN"
63
 
 
64
  FORBIDDEN_NODES = (
65
  exp.Insert,
66
  exp.Delete,
@@ -72,6 +82,7 @@ FORBIDDEN_NODES = (
72
  exp.Create,
73
  )
74
 
 
75
  def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
76
  try:
77
  ast = parse_one(sql, read=dialect)
@@ -84,6 +95,7 @@ def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
84
  return False
85
  return True
86
 
 
87
  def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
88
  data = load_spider_sqlite(split)
89
  if len(data) < limit:
@@ -94,8 +106,8 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
94
  commit_hash = get_git_commit_hash()
95
  start_ts = int(time.time())
96
 
97
- pred_txt = LOG_DIR / f"{split}_pred_{start_ts}.txt"
98
- gold_txt = LOG_DIR / f"{split}_gold_{start_ts}.txt"
99
  results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
100
  metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
101
 
@@ -112,10 +124,11 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
112
  pass
113
 
114
  write_header = not results_fn.exists()
115
- with results_fn.open("a", encoding="utf-8") as fout, \
116
- pred_txt.open("a", encoding="utf-8") as fpred, \
117
- gold_txt.open("a", encoding="utf-8") as fgold:
118
-
 
119
  if write_header:
120
  header = {
121
  "commit_hash": commit_hash,
@@ -228,21 +241,28 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
228
  if sleep_time > 0:
229
  time.sleep(sleep_time)
230
 
231
-
232
- valid = [r for r in agg if (not r.get("safe_check_failed", False)) and r.get("gold_error") is None]
 
 
 
233
  total_valid = len(valid)
234
  total_all = len(agg)
235
  if total_valid == 0:
236
  print("No valid examples to compute metrics")
237
  return
238
 
239
- em_count = sum(1 for r in valid if r["exact_match"])
240
  em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
241
- exec_acc_count = sum(1 for r in valid if r["execution_accuracy"])
242
- error_count = sum(1 for r in agg if (r.get("error") is not None) and (not r.get("safe_check_failed", False)))
 
 
 
 
243
  safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
244
- avg_gen_time = sum(r["gen_time"] for r in valid) / total_valid
245
- avg_exec_time = sum(r["exec_time"] for r in valid) / total_valid
246
 
247
  metrics = {
248
  "commit_hash": commit_hash,
 
13
  LOG_DIR = Path("logs/spider_eval")
14
  LOG_DIR.mkdir(parents=True, exist_ok=True)
15
 
16
+
17
  def normalize_sql(sql: str) -> str:
18
  # نسخه ساده؛ می‌تونی قوی‌ترش کنی با پارس + بازسازی
19
  return " ".join(sql.lower().strip().split())
20
 
21
+
22
  def compare_results(pred_rows, gold_rows):
23
  if pred_rows is None or gold_rows is None:
24
  return False
25
  # اگر ترتیب مهم نیست
26
  return set(pred_rows) == set(gold_rows)
27
 
28
+
29
  def try_execute_sql(sql_db, sql, timeout: float = None):
30
  start = time.time()
31
  try:
 
34
  except Exception as e:
35
  return None, time.time() - start, str(e)
36
 
37
+
38
  def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
39
  try:
40
  ast_pred = parse_one(sql_pred)
 
58
  norm_gold = normalize_ast(ast_gold)
59
  return norm_prd == norm_gold
60
 
61
+
62
  def get_git_commit_hash() -> str:
63
  try:
64
+ out = (
65
+ subprocess.check_output(["git", "rev-parse", "HEAD"])
66
+ .strip()
67
+ .decode("ascii")
68
+ )
69
  return out
70
  except Exception:
71
  return "UNKNOWN"
72
 
73
+
74
  FORBIDDEN_NODES = (
75
  exp.Insert,
76
  exp.Delete,
 
82
  exp.Create,
83
  )
84
 
85
+
86
  def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
87
  try:
88
  ast = parse_one(sql, read=dialect)
 
95
  return False
96
  return True
97
 
98
+
99
  def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
100
  data = load_spider_sqlite(split)
101
  if len(data) < limit:
 
106
  commit_hash = get_git_commit_hash()
107
  start_ts = int(time.time())
108
 
109
+ pred_txt = LOG_DIR / f"{split}_pred_{start_ts}.txt"
110
+ gold_txt = LOG_DIR / f"{split}_gold_{start_ts}.txt"
111
  results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
112
  metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
113
 
 
124
  pass
125
 
126
  write_header = not results_fn.exists()
127
+ with (
128
+ results_fn.open("a", encoding="utf-8") as fout,
129
+ pred_txt.open("a", encoding="utf-8") as fpred,
130
+ gold_txt.open("a", encoding="utf-8") as fgold,
131
+ ):
132
  if write_header:
133
  header = {
134
  "commit_hash": commit_hash,
 
241
  if sleep_time > 0:
242
  time.sleep(sleep_time)
243
 
244
+ valid = [
245
+ r
246
+ for r in agg
247
+ if (not r.get("safe_check_failed", False)) and r.get("gold_error") is None
248
+ ]
249
  total_valid = len(valid)
250
  total_all = len(agg)
251
  if total_valid == 0:
252
  print("No valid examples to compute metrics")
253
  return
254
 
255
+ em_count = sum(1 for r in valid if r["exact_match"])
256
  em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
257
+ exec_acc_count = sum(1 for r in valid if r["execution_accuracy"])
258
+ error_count = sum(
259
+ 1
260
+ for r in agg
261
+ if (r.get("error") is not None) and (not r.get("safe_check_failed", False))
262
+ )
263
  safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
264
+ avg_gen_time = sum(r["gen_time"] for r in valid) / total_valid
265
+ avg_exec_time = sum(r["exec_time"] for r in valid) / total_valid
266
 
267
  metrics = {
268
  "commit_hash": commit_hash,
benchmarks/run.py CHANGED
@@ -20,6 +20,7 @@ from nl2sql.repair import Repair
20
  from adapters.db.sqlite_adapter import SQLiteAdapter
21
  from adapters.llm.openai_provider import OpenAIProvider
22
 
 
23
  # ---- fallbacks: Dummy LLM (so it runs without API keys)
24
  class DummyLLM:
25
  provider_id = "dummy-llm"
@@ -28,7 +29,14 @@ class DummyLLM:
28
  text = f"- understand question: {user_query}\n- identify tables\n- join if needed\n- filter\n- order/limit"
29
  return text, 0, 0, 0.0
30
 
31
- def generate_sql(self, *, user_query: str, schema_preview: str, plan_text: str, clarify_answers=None):
 
 
 
 
 
 
 
32
  # naive demo SQL (so pipeline flows end-to-end)
33
  sql = "SELECT 1 AS one;"
34
  rationale = "Demo SQL from DummyLLM"
@@ -43,12 +51,15 @@ def ensure_demo_db(path: Path) -> None:
43
  if path.exists():
44
  return
45
  import sqlite3
 
46
  path.parent.mkdir(parents=True, exist_ok=True)
47
  con = sqlite3.connect(path)
48
  cur = con.cursor()
49
  cur.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
50
- cur.executemany("INSERT INTO users(id,name,spend) VALUES(?,?,?)",
51
- [(1,"Alice",120.5),(2,"Bob",80.0),(3,"Carol",155.0)])
 
 
52
  con.commit()
53
  con.close()
54
 
@@ -86,7 +97,7 @@ def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
86
  for q in queries:
87
  t0 = time.perf_counter()
88
  r = pipeline.run(user_query=q, schema_preview=schema_preview)
89
- latency_ms = (time.perf_counter()-t0)*1000
90
  ok = (not r.get("ambiguous")) and ("error" not in r)
91
 
92
  traces = r.get("traces", [])
@@ -97,15 +108,19 @@ def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
97
  except Exception:
98
  pass
99
 
100
- results.append({
101
- "query": q,
102
- "exec_acc": 1.0 if ok else 0.0,
103
- "safe_fail": 0.0 if ok else 1.0 if "unsafe" in str(r).lower() else 0.0,
104
- "latency_ms": latency_ms,
105
- "cost_usd": cost_sum,
106
- "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
107
- "provider": pipeline.generator.llm.provider_id if hasattr(pipeline.generator, "llm") else "unknown",
108
- })
 
 
 
 
109
 
110
  outfile.parent.mkdir(parents=True, exist_ok=True)
111
  with open(outfile, "w") as f:
@@ -118,10 +133,14 @@ def main():
118
  parser = argparse.ArgumentParser()
119
  parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
120
  parser.add_argument("--db", default="data/bench_demo.db")
121
- parser.add_argument("--use-openai", action="store_true", help="Use OpenAI provider if API key present")
 
 
 
 
122
  args = parser.parse_args()
123
 
124
- ROOT = Path(__file__).resolve().parents[1] # project root
125
  outfile = (ROOT / args.outfile).resolve()
126
  db_path = (ROOT / args.db).resolve()
127
 
 
20
  from adapters.db.sqlite_adapter import SQLiteAdapter
21
  from adapters.llm.openai_provider import OpenAIProvider
22
 
23
+
24
  # ---- fallbacks: Dummy LLM (so it runs without API keys)
25
  class DummyLLM:
26
  provider_id = "dummy-llm"
 
29
  text = f"- understand question: {user_query}\n- identify tables\n- join if needed\n- filter\n- order/limit"
30
  return text, 0, 0, 0.0
31
 
32
+ def generate_sql(
33
+ self,
34
+ *,
35
+ user_query: str,
36
+ schema_preview: str,
37
+ plan_text: str,
38
+ clarify_answers=None,
39
+ ):
40
  # naive demo SQL (so pipeline flows end-to-end)
41
  sql = "SELECT 1 AS one;"
42
  rationale = "Demo SQL from DummyLLM"
 
51
  if path.exists():
52
  return
53
  import sqlite3
54
+
55
  path.parent.mkdir(parents=True, exist_ok=True)
56
  con = sqlite3.connect(path)
57
  cur = con.cursor()
58
  cur.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
59
+ cur.executemany(
60
+ "INSERT INTO users(id,name,spend) VALUES(?,?,?)",
61
+ [(1, "Alice", 120.5), (2, "Bob", 80.0), (3, "Carol", 155.0)],
62
+ )
63
  con.commit()
64
  con.close()
65
 
 
97
  for q in queries:
98
  t0 = time.perf_counter()
99
  r = pipeline.run(user_query=q, schema_preview=schema_preview)
100
+ latency_ms = (time.perf_counter() - t0) * 1000
101
  ok = (not r.get("ambiguous")) and ("error" not in r)
102
 
103
  traces = r.get("traces", [])
 
108
  except Exception:
109
  pass
110
 
111
+ results.append(
112
+ {
113
+ "query": q,
114
+ "exec_acc": 1.0 if ok else 0.0,
115
+ "safe_fail": 0.0 if ok else 1.0 if "unsafe" in str(r).lower() else 0.0,
116
+ "latency_ms": latency_ms,
117
+ "cost_usd": cost_sum,
118
+ "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
119
+ "provider": pipeline.generator.llm.provider_id
120
+ if hasattr(pipeline.generator, "llm")
121
+ else "unknown",
122
+ }
123
+ )
124
 
125
  outfile.parent.mkdir(parents=True, exist_ok=True)
126
  with open(outfile, "w") as f:
 
133
  parser = argparse.ArgumentParser()
134
  parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
135
  parser.add_argument("--db", default="data/bench_demo.db")
136
+ parser.add_argument(
137
+ "--use-openai",
138
+ action="store_true",
139
+ help="Use OpenAI provider if API key present",
140
+ )
141
  args = parser.parse_args()
142
 
143
+ ROOT = Path(__file__).resolve().parents[1] # project root
144
  outfile = (ROOT / args.outfile).resolve()
145
  db_path = (ROOT / args.db).resolve()
146
 
benchmarks/spider_loader.py CHANGED
@@ -4,9 +4,8 @@ from dataclasses import dataclass
4
  from typing import List, Optional
5
  import os
6
 
7
- SPIDER_ROOT = pathlib.Path(
8
- os.getenv("SPIDER_ROOT", "data/spider")
9
- )
10
 
11
  @dataclass
12
  class SpiderItem:
@@ -15,7 +14,10 @@ class SpiderItem:
15
  gold_sql: str
16
  db_path: pathlib.Path
17
 
18
- def load_spider_sqlite(split: str = "dev", limit: Optional[int] = None) -> List[SpiderItem]:
 
 
 
19
  fn = {"dev": "dev.json", "train": "train_spider.json"}[split]
20
  json_path = SPIDER_ROOT / fn
21
  try:
@@ -23,7 +25,6 @@ def load_spider_sqlite(split: str = "dev", limit: Optional[int] = None) -> List[
23
  except Exception as e:
24
  raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})")
25
 
26
-
27
  out: list[SpiderItem] = []
28
  for ex in items[: (limit or len(items))]:
29
  db_id = ex["db_id"]
@@ -35,14 +36,16 @@ def load_spider_sqlite(split: str = "dev", limit: Optional[int] = None) -> List[
35
  db_id=db_id,
36
  question=ex["question"],
37
  gold_sql=ex["query"],
38
- db_path=db_path
39
  )
40
  )
41
  return out
42
 
43
 
44
- def open_readonly_connection(db_path: pathlib.Path, timeout: float=5.0) -> sqlite3.Connection:
 
 
45
  uri = f"file:{db_path}?mode=ro&uri=true"
46
  conn = sqlite3.connect(uri, uri=True, timeout=timeout)
47
  conn.row_factory = sqlite3.Row
48
- return conn
 
4
  from typing import List, Optional
5
  import os
6
 
7
+ SPIDER_ROOT = pathlib.Path(os.getenv("SPIDER_ROOT", "data/spider"))
8
+
 
9
 
10
  @dataclass
11
  class SpiderItem:
 
14
  gold_sql: str
15
  db_path: pathlib.Path
16
 
17
+
18
+ def load_spider_sqlite(
19
+ split: str = "dev", limit: Optional[int] = None
20
+ ) -> List[SpiderItem]:
21
  fn = {"dev": "dev.json", "train": "train_spider.json"}[split]
22
  json_path = SPIDER_ROOT / fn
23
  try:
 
25
  except Exception as e:
26
  raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})")
27
 
 
28
  out: list[SpiderItem] = []
29
  for ex in items[: (limit or len(items))]:
30
  db_id = ex["db_id"]
 
36
  db_id=db_id,
37
  question=ex["question"],
38
  gold_sql=ex["query"],
39
+ db_path=db_path,
40
  )
41
  )
42
  return out
43
 
44
 
45
+ def open_readonly_connection(
46
+ db_path: pathlib.Path, timeout: float = 5.0
47
+ ) -> sqlite3.Connection:
48
  uri = f"file:{db_path}?mode=ro&uri=true"
49
  conn = sqlite3.connect(uri, uri=True, timeout=timeout)
50
  conn.row_factory = sqlite3.Row
51
+ return conn
config.py CHANGED
@@ -4,12 +4,15 @@ from dotenv import load_dotenv
4
  load_dotenv()
5
 
6
 
7
- def get_env_var(name: str, required: bool = True, default: str | None = None) -> str | None:
 
 
8
  val = os.getenv(name, default)
9
  if required and not val:
10
  raise ValueError(f"Missing required environment variable: {name}")
11
  return val
12
 
 
13
  proxy_key = os.getenv("PROXY_API_KEY")
14
  proxy_base = os.getenv("PROXY_BASE_URL")
15
  openai_key = os.getenv("OPENAI_API_KEY")
@@ -17,7 +20,9 @@ openai_base = os.getenv("OPENAI_BASE_URL")
17
 
18
  api_key = proxy_key or openai_key
19
  if not api_key:
20
- raise ValueError("Missing API key: set PROXY_API_KEY or OPENAI_API_KEY in environment/secrets.")
 
 
21
 
22
  base_url = proxy_base or openai_base or "https://api.openai.com/v1"
23
 
@@ -33,11 +38,24 @@ LLM_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") # or gpt-4o, gpt-4o-mini
33
  LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
34
 
35
  FORBIDDEN_KEYWORDS = {
36
- "ATTACH", "PRAGMA",
37
- "CREATE", "DROP", "ALTER", "VACUUM", "REINDEX", "TRIGGER",
38
- "INSERT", "UPDATE", "DELETE", "REPLACE",
39
- "GRANT", "REVOKE",
40
- "BEGIN", "END", "COMMIT", "ROLLBACK",
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  "DETACH",
42
  }
43
  FORBIDDEN_TABLES = {"sqlite_master", "sqlite_temp_master"}
 
4
  load_dotenv()
5
 
6
 
7
+ def get_env_var(
8
+ name: str, required: bool = True, default: str | None = None
9
+ ) -> str | None:
10
  val = os.getenv(name, default)
11
  if required and not val:
12
  raise ValueError(f"Missing required environment variable: {name}")
13
  return val
14
 
15
+
16
  proxy_key = os.getenv("PROXY_API_KEY")
17
  proxy_base = os.getenv("PROXY_BASE_URL")
18
  openai_key = os.getenv("OPENAI_API_KEY")
 
20
 
21
  api_key = proxy_key or openai_key
22
  if not api_key:
23
+ raise ValueError(
24
+ "Missing API key: set PROXY_API_KEY or OPENAI_API_KEY in environment/secrets."
25
+ )
26
 
27
  base_url = proxy_base or openai_base or "https://api.openai.com/v1"
28
 
 
38
  LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
39
 
40
  FORBIDDEN_KEYWORDS = {
41
+ "ATTACH",
42
+ "PRAGMA",
43
+ "CREATE",
44
+ "DROP",
45
+ "ALTER",
46
+ "VACUUM",
47
+ "REINDEX",
48
+ "TRIGGER",
49
+ "INSERT",
50
+ "UPDATE",
51
+ "DELETE",
52
+ "REPLACE",
53
+ "GRANT",
54
+ "REVOKE",
55
+ "BEGIN",
56
+ "END",
57
+ "COMMIT",
58
+ "ROLLBACK",
59
  "DETACH",
60
  }
61
  FORBIDDEN_TABLES = {"sqlite_master", "sqlite_temp_master"}
nl2sql/ambiguity_detector.py CHANGED
@@ -1,16 +1,17 @@
1
  import re
2
  from typing import List
3
 
 
4
  class AmbiguityDetector:
5
  """Lightweight AmbiSQL-style ambiguity detection."""
6
 
7
  AMBIGUOUS_TERMS = ["recent", "top", "name", "rank", "latest"]
8
 
9
- def detect(self, query:str, schema_preview: str) -> list[str]:
10
  hits = []
11
  q_lower = query.lower()
12
  for term in self.AMBIGUOUS_TERMS:
13
  if re.search(rf"\b{term}\b", q_lower):
14
  hits.append(f"The term '{term}' is ambiguous in this query.'")
15
 
16
- return hits
 
1
  import re
2
  from typing import List
3
 
4
+
5
  class AmbiguityDetector:
6
  """Lightweight AmbiSQL-style ambiguity detection."""
7
 
8
  AMBIGUOUS_TERMS = ["recent", "top", "name", "rank", "latest"]
9
 
10
+ def detect(self, query: str, schema_preview: str) -> list[str]:
11
  hits = []
12
  q_lower = query.lower()
13
  for term in self.AMBIGUOUS_TERMS:
14
  if re.search(rf"\b{term}\b", q_lower):
15
  hits.append(f"The term '{term}' is ambiguous in this query.'")
16
 
17
+ return hits
nl2sql/executor.py CHANGED
@@ -2,6 +2,7 @@ import time
2
  from nl2sql.types import StageResult, StageTrace
3
  from adapters.db.base import DBAdapter
4
 
 
5
  class Executor:
6
  name = "executor"
7
 
@@ -12,10 +13,18 @@ class Executor:
12
  t0 = time.perf_counter()
13
  try:
14
  rows, cols = self.db.execute(sql)
15
- trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
16
- notes={"row_count": len(rows), "col_count": len(cols)})
17
- return StageResult(ok=True, data={"rows": rows, "columns": cols}, trace=trace)
 
 
 
 
 
18
  except Exception as e:
19
- trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
20
- notes={"error": str(e)})
 
 
 
21
  return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
 
2
  from nl2sql.types import StageResult, StageTrace
3
  from adapters.db.base import DBAdapter
4
 
5
+
6
  class Executor:
7
  name = "executor"
8
 
 
13
  t0 = time.perf_counter()
14
  try:
15
  rows, cols = self.db.execute(sql)
16
+ trace = StageTrace(
17
+ stage=self.name,
18
+ duration_ms=(time.perf_counter() - t0) * 1000,
19
+ notes={"row_count": len(rows), "col_count": len(cols)},
20
+ )
21
+ return StageResult(
22
+ ok=True, data={"rows": rows, "columns": cols}, trace=trace
23
+ )
24
  except Exception as e:
25
+ trace = StageTrace(
26
+ stage=self.name,
27
+ duration_ms=(time.perf_counter() - t0) * 1000,
28
+ notes={"error": str(e)},
29
+ )
30
  return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
nl2sql/generator.py CHANGED
@@ -4,34 +4,48 @@ from typing import Optional, Dict, Any
4
  from nl2sql.types import StageResult, StageTrace
5
  from adapters.llm.base import LLMProvider
6
 
 
7
  class Generator:
8
  name = "generator"
9
 
10
  def __init__(self, llm: LLMProvider) -> None:
11
  self.llm = llm
12
 
13
- def run(self, *, user_query: str, schema_preview: str, plan_text: str,
14
- clarify_answers: Optional[Dict[str, Any]] = None) -> StageResult:
 
 
 
 
 
 
15
  t0 = time.perf_counter()
16
  try:
17
  res = self.llm.generate_sql(
18
  user_query=user_query,
19
  schema_preview=schema_preview,
20
  plan_text=plan_text,
21
- clarify_answers=clarify_answers or {}
22
  )
23
  except Exception as e:
24
  return StageResult(ok=False, error=[f"Generator failed: {e}"])
25
 
26
  # Expect a 5-tuple
27
  if not isinstance(res, tuple) or len(res) != 5:
28
- return StageResult(ok=False, error=["Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"])
 
 
 
 
 
29
 
30
  sql, rationale, t_in, t_out, cost = res
31
 
32
  # Type/shape checks
33
  if not isinstance(sql, str) or not sql.strip():
34
- return StageResult(ok=False, error=["Generator produced empty or non-string SQL"])
 
 
35
  if not sql.lower().lstrip().startswith("select"):
36
  return StageResult(ok=False, error=[f"Generated non-SELECT SQL: {sql}"])
37
 
@@ -45,5 +59,6 @@ class Generator:
45
  notes={"rationale_len": len(rationale)},
46
  )
47
 
48
- return StageResult(ok=True, data={"sql": sql, "rationale": rationale}, trace=trace)
49
-
 
 
4
  from nl2sql.types import StageResult, StageTrace
5
  from adapters.llm.base import LLMProvider
6
 
7
+
8
  class Generator:
9
  name = "generator"
10
 
11
  def __init__(self, llm: LLMProvider) -> None:
12
  self.llm = llm
13
 
14
+ def run(
15
+ self,
16
+ *,
17
+ user_query: str,
18
+ schema_preview: str,
19
+ plan_text: str,
20
+ clarify_answers: Optional[Dict[str, Any]] = None,
21
+ ) -> StageResult:
22
  t0 = time.perf_counter()
23
  try:
24
  res = self.llm.generate_sql(
25
  user_query=user_query,
26
  schema_preview=schema_preview,
27
  plan_text=plan_text,
28
+ clarify_answers=clarify_answers or {},
29
  )
30
  except Exception as e:
31
  return StageResult(ok=False, error=[f"Generator failed: {e}"])
32
 
33
  # Expect a 5-tuple
34
  if not isinstance(res, tuple) or len(res) != 5:
35
+ return StageResult(
36
+ ok=False,
37
+ error=[
38
+ "Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"
39
+ ],
40
+ )
41
 
42
  sql, rationale, t_in, t_out, cost = res
43
 
44
  # Type/shape checks
45
  if not isinstance(sql, str) or not sql.strip():
46
+ return StageResult(
47
+ ok=False, error=["Generator produced empty or non-string SQL"]
48
+ )
49
  if not sql.lower().lstrip().startswith("select"):
50
  return StageResult(ok=False, error=[f"Generated non-SELECT SQL: {sql}"])
51
 
 
59
  notes={"rationale_len": len(rationale)},
60
  )
61
 
62
+ return StageResult(
63
+ ok=True, data={"sql": sql, "rationale": rationale}, trace=trace
64
+ )
nl2sql/pipeline.py CHANGED
@@ -17,14 +17,17 @@ class Pipeline:
17
  All stages return structured traces and errors but final result is JSON-safe dict.
18
  """
19
 
20
- def __init__(self, *,
21
- detector: AmbiguityDetector,
22
- planner: Planner,
23
- generator: Generator,
24
- safety: Safety,
25
- executor: Executor,
26
- verifier: Verifier,
27
- repair: Repair):
 
 
 
28
  self.detector = detector
29
  self.planner = planner
30
  self.generator = generator
@@ -59,8 +62,13 @@ class Pipeline:
59
  return StageResult(ok=False, data=None, trace=None, errors=[f"{e}", tb])
60
 
61
  # ------------------------------------------------------------
62
- def run(self, *, user_query: str, schema_preview: str,
63
- clarify_answers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
 
 
 
 
 
64
  """
65
  Always returns:
66
  {
@@ -86,26 +94,45 @@ class Pipeline:
86
  "error": False,
87
  "details": [f"Ambiguities found: {len(questions)}"],
88
  "questions": questions,
89
- "traces": []
90
  }
91
  except Exception as e:
92
- return {"ambiguous": True, "error": True, "details": [f"Detector failed: {e}"], "traces": []}
 
 
 
 
 
93
 
94
  # --- 2) planner
95
- r_plan = self._safe_stage(self.planner.run, user_query=user_query, schema_preview=schema_preview)
 
 
96
  traces.extend(self._trace_list(r_plan))
97
  if not r_plan.ok:
98
- return {"ambiguous": False, "error": True, "details": r_plan.errors, "traces": traces}
 
 
 
 
 
99
 
100
  # --- 3) generator
101
- r_gen = self._safe_stage(self.generator.run,
102
- user_query=user_query,
103
- schema_preview=schema_preview,
104
- plan_text=r_plan.data.get("plan"),
105
- clarify_answers=clarify_answers or {})
 
 
106
  traces.extend(self._trace_list(r_gen))
107
  if not r_gen.ok:
108
- return {"ambiguous": False, "error": True, "details": r_gen.errors, "traces": traces}
 
 
 
 
 
109
  sql = r_gen.data.get("sql")
110
  rationale = r_gen.data.get("rationale")
111
 
@@ -113,7 +140,12 @@ class Pipeline:
113
  r_safe = self._safe_stage(self.safety.check, sql=sql)
114
  traces.extend(self._trace_list(r_safe))
115
  if not r_safe.ok:
116
- return {"ambiguous": False, "error": True, "details": r_safe.errors, "traces": traces}
 
 
 
 
 
117
 
118
  # --- 5) executor
119
  r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
@@ -129,10 +161,12 @@ class Pipeline:
129
  # --- 7) repair loop if verification failed
130
  if not verified:
131
  for attempt in range(2):
132
- r_fix = self._safe_stage(self.repair.run,
133
- sql=sql,
134
- error_msg="; ".join(details or ["unknown"]),
135
- schema_preview=schema_preview)
 
 
136
  traces.extend(self._trace_list(r_fix))
137
  if not r_fix.ok:
138
  break
 
17
  All stages return structured traces and errors but final result is JSON-safe dict.
18
  """
19
 
20
+ def __init__(
21
+ self,
22
+ *,
23
+ detector: AmbiguityDetector,
24
+ planner: Planner,
25
+ generator: Generator,
26
+ safety: Safety,
27
+ executor: Executor,
28
+ verifier: Verifier,
29
+ repair: Repair,
30
+ ):
31
  self.detector = detector
32
  self.planner = planner
33
  self.generator = generator
 
62
  return StageResult(ok=False, data=None, trace=None, errors=[f"{e}", tb])
63
 
64
  # ------------------------------------------------------------
65
+ def run(
66
+ self,
67
+ *,
68
+ user_query: str,
69
+ schema_preview: str,
70
+ clarify_answers: Optional[Dict[str, Any]] = None,
71
+ ) -> Dict[str, Any]:
72
  """
73
  Always returns:
74
  {
 
94
  "error": False,
95
  "details": [f"Ambiguities found: {len(questions)}"],
96
  "questions": questions,
97
+ "traces": [],
98
  }
99
  except Exception as e:
100
+ return {
101
+ "ambiguous": True,
102
+ "error": True,
103
+ "details": [f"Detector failed: {e}"],
104
+ "traces": [],
105
+ }
106
 
107
  # --- 2) planner
108
+ r_plan = self._safe_stage(
109
+ self.planner.run, user_query=user_query, schema_preview=schema_preview
110
+ )
111
  traces.extend(self._trace_list(r_plan))
112
  if not r_plan.ok:
113
+ return {
114
+ "ambiguous": False,
115
+ "error": True,
116
+ "details": r_plan.errors,
117
+ "traces": traces,
118
+ }
119
 
120
  # --- 3) generator
121
+ r_gen = self._safe_stage(
122
+ self.generator.run,
123
+ user_query=user_query,
124
+ schema_preview=schema_preview,
125
+ plan_text=r_plan.data.get("plan"),
126
+ clarify_answers=clarify_answers or {},
127
+ )
128
  traces.extend(self._trace_list(r_gen))
129
  if not r_gen.ok:
130
+ return {
131
+ "ambiguous": False,
132
+ "error": True,
133
+ "details": r_gen.errors,
134
+ "traces": traces,
135
+ }
136
  sql = r_gen.data.get("sql")
137
  rationale = r_gen.data.get("rationale")
138
 
 
140
  r_safe = self._safe_stage(self.safety.check, sql=sql)
141
  traces.extend(self._trace_list(r_safe))
142
  if not r_safe.ok:
143
+ return {
144
+ "ambiguous": False,
145
+ "error": True,
146
+ "details": r_safe.errors,
147
+ "traces": traces,
148
+ }
149
 
150
  # --- 5) executor
151
  r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
 
161
  # --- 7) repair loop if verification failed
162
  if not verified:
163
  for attempt in range(2):
164
+ r_fix = self._safe_stage(
165
+ self.repair.run,
166
+ sql=sql,
167
+ error_msg="; ".join(details or ["unknown"]),
168
+ schema_preview=schema_preview,
169
+ )
170
  traces.extend(self._trace_list(r_fix))
171
  if not r_fix.ok:
172
  break
nl2sql/planner.py CHANGED
@@ -3,14 +3,24 @@ import time
3
  from nl2sql.types import StageResult, StageTrace
4
  from adapters.llm.base import LLMProvider
5
 
 
6
  class Planner:
7
  name = "planner"
 
8
  def __init__(self, llm: LLMProvider) -> None:
9
  self.llm = llm
10
 
11
  def run(self, *, user_query: str, schema_preview: str) -> StageResult:
12
  t0 = time.perf_counter()
13
- plan_text, t_in, t_out, cost = self.llm.plan(user_query=user_query, schema_preview=schema_preview)
14
- trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
15
- token_in=t_in, token_out=t_out, cost_usd=cost, notes={"len_plan": len(plan_text)})
 
 
 
 
 
 
 
 
16
  return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
 
3
  from nl2sql.types import StageResult, StageTrace
4
  from adapters.llm.base import LLMProvider
5
 
6
+
7
  class Planner:
8
  name = "planner"
9
+
10
  def __init__(self, llm: LLMProvider) -> None:
11
  self.llm = llm
12
 
13
  def run(self, *, user_query: str, schema_preview: str) -> StageResult:
14
  t0 = time.perf_counter()
15
+ plan_text, t_in, t_out, cost = self.llm.plan(
16
+ user_query=user_query, schema_preview=schema_preview
17
+ )
18
+ trace = StageTrace(
19
+ stage=self.name,
20
+ duration_ms=(time.perf_counter() - t0) * 1000,
21
+ token_in=t_in,
22
+ token_out=t_out,
23
+ cost_usd=cost,
24
+ notes={"len_plan": len(plan_text)},
25
+ )
26
  return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
nl2sql/repair.py CHANGED
@@ -14,16 +14,26 @@ When repairing:
14
  Return only the corrected SQL.
15
  """
16
 
 
17
  class Repair:
18
  name = "repair"
 
19
  def __init__(self, llm: LLMProvider):
20
  self.llm = llm
21
 
22
- def run(self, sql:str, error_msg: str, schema_preview: str) -> StageResult:
23
  t0 = time.perf_counter()
24
- fixed_sql, t_in, t_out, cost = self.llm.repair(sql=sql, error_msg=f"{GUIDELINES}\n\n{error_msg}",
25
- schema_preview=schema_preview)
26
- trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
27
- token_in=t_in, token_out=t_out, cost_usd=cost,
28
- notes={"old_sql_len": len(sql), "new_sql_len": len(fixed_sql)})
 
 
 
 
 
 
 
 
29
  return StageResult(ok=True, data={"sql": fixed_sql}, trace=trace)
 
14
  Return only the corrected SQL.
15
  """
16
 
17
+
18
  class Repair:
19
  name = "repair"
20
+
21
  def __init__(self, llm: LLMProvider):
22
  self.llm = llm
23
 
24
+ def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
25
  t0 = time.perf_counter()
26
+ fixed_sql, t_in, t_out, cost = self.llm.repair(
27
+ sql=sql,
28
+ error_msg=f"{GUIDELINES}\n\n{error_msg}",
29
+ schema_preview=schema_preview,
30
+ )
31
+ trace = StageTrace(
32
+ stage=self.name,
33
+ duration_ms=(time.perf_counter() - t0) * 1000,
34
+ token_in=t_in,
35
+ token_out=t_out,
36
+ cost_usd=cost,
37
+ notes={"old_sql_len": len(sql), "new_sql_len": len(fixed_sql)},
38
+ )
39
  return StageResult(ok=True, data={"sql": fixed_sql}, trace=trace)
nl2sql/safety.py CHANGED
@@ -4,7 +4,7 @@ from nl2sql.types import StageResult, StageTrace
4
 
5
  # --- Regex utils ---
6
  _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
7
- _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
8
  # string literals (single & double quotes), allow escaped quotes
9
  _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
10
  _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
@@ -18,20 +18,24 @@ _FORBIDDEN = re.compile(
18
  # allow: SELECT ... or WITH <cte...> SELECT ...
19
  _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
20
 
 
21
  def _strip_comments(s: str) -> str:
22
  s = _COMMENT_BLOCK.sub(" ", s)
23
  s = _COMMENT_LINE.sub(" ", s)
24
  return s
25
 
 
26
  def _mask_strings(s: str) -> str:
27
  s = _STRING_SINGLE.sub("'X'", s)
28
  s = _STRING_DOUBLE.sub('"X"', s)
29
  return s
30
 
 
31
  def _split_statements(s: str) -> list[str]:
32
  parts = [p.strip() for p in s.split(";")]
33
  return [p for p in parts if p]
34
 
 
35
  class Safety:
36
  name = "safety"
37
 
@@ -46,7 +50,9 @@ class Safety:
46
  return StageResult(
47
  ok=False,
48
  error=["Multiple statements detected"],
49
- trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
 
 
50
  )
51
 
52
  body = stmts[0]
@@ -55,14 +61,18 @@ class Safety:
55
  return StageResult(
56
  ok=False,
57
  error=["Forbidden keyword detected"],
58
- trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
 
 
59
  )
60
 
61
  if not _ALLOW_SELECT.match(body):
62
  return StageResult(
63
  ok=False,
64
  error=["Non-SELECT statement"],
65
- trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
 
 
66
  )
67
 
68
  return StageResult(
@@ -71,5 +81,7 @@ class Safety:
71
  "sql": sql.strip(),
72
  "rationale": "Statement validated as SELECT-only (strings/comments ignored).",
73
  },
74
- trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
 
 
75
  )
 
4
 
5
  # --- Regex utils ---
6
  _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
7
+ _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
8
  # string literals (single & double quotes), allow escaped quotes
9
  _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
10
  _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
 
18
  # allow: SELECT ... or WITH <cte...> SELECT ...
19
  _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
20
 
21
+
22
  def _strip_comments(s: str) -> str:
23
  s = _COMMENT_BLOCK.sub(" ", s)
24
  s = _COMMENT_LINE.sub(" ", s)
25
  return s
26
 
27
+
28
  def _mask_strings(s: str) -> str:
29
  s = _STRING_SINGLE.sub("'X'", s)
30
  s = _STRING_DOUBLE.sub('"X"', s)
31
  return s
32
 
33
+
34
  def _split_statements(s: str) -> list[str]:
35
  parts = [p.strip() for p in s.split(";")]
36
  return [p for p in parts if p]
37
 
38
+
39
  class Safety:
40
  name = "safety"
41
 
 
50
  return StageResult(
51
  ok=False,
52
  error=["Multiple statements detected"],
53
+ trace=StageTrace(
54
+ stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
55
+ ),
56
  )
57
 
58
  body = stmts[0]
 
61
  return StageResult(
62
  ok=False,
63
  error=["Forbidden keyword detected"],
64
+ trace=StageTrace(
65
+ stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
66
+ ),
67
  )
68
 
69
  if not _ALLOW_SELECT.match(body):
70
  return StageResult(
71
  ok=False,
72
  error=["Non-SELECT statement"],
73
+ trace=StageTrace(
74
+ stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
75
+ ),
76
  )
77
 
78
  return StageResult(
 
81
  "sql": sql.strip(),
82
  "rationale": "Statement validated as SELECT-only (strings/comments ignored).",
83
  },
84
+ trace=StageTrace(
85
+ stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
86
+ ),
87
  )
nl2sql/stubs.py CHANGED
@@ -1,31 +1,37 @@
1
  from nl2sql.types import StageResult, StageTrace
2
 
 
3
  class NoOpExecutor:
4
  name = "executor"
 
5
  def run(self, sql: str) -> StageResult:
6
  # pretend success, return empty result set
7
  return StageResult(
8
  ok=True,
9
  data={"rows": [], "columns": []},
10
- trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
11
  )
12
 
 
13
  class NoOpVerifier:
14
  name = "verifier"
 
15
  def run(self, sql: str, exec_result: StageResult) -> StageResult:
16
  # always verified for legacy tests
17
  return StageResult(
18
  ok=True,
19
  data={"verified": True},
20
- trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
21
  )
22
 
 
23
  class NoOpRepair:
24
  name = "repair"
 
25
  def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
26
  # return original SQL unchanged
27
  return StageResult(
28
  ok=True,
29
  data={"sql": sql},
30
- trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
31
  )
 
1
  from nl2sql.types import StageResult, StageTrace
2
 
3
+
4
  class NoOpExecutor:
5
  name = "executor"
6
+
7
  def run(self, sql: str) -> StageResult:
8
  # pretend success, return empty result set
9
  return StageResult(
10
  ok=True,
11
  data={"rows": [], "columns": []},
12
+ trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True}),
13
  )
14
 
15
+
16
  class NoOpVerifier:
17
  name = "verifier"
18
+
19
  def run(self, sql: str, exec_result: StageResult) -> StageResult:
20
  # always verified for legacy tests
21
  return StageResult(
22
  ok=True,
23
  data={"verified": True},
24
+ trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True}),
25
  )
26
 
27
+
28
  class NoOpRepair:
29
  name = "repair"
30
+
31
  def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
32
  # return original SQL unchanged
33
  return StageResult(
34
  ok=True,
35
  data={"sql": sql},
36
+ trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True}),
37
  )
nl2sql/types.py CHANGED
@@ -1,6 +1,7 @@
1
  from dataclasses import dataclass
2
  from typing import Any, Dict, Optional, List
3
 
 
4
  @dataclass(frozen=True)
5
  class StageTrace:
6
  stage: str
@@ -10,6 +11,7 @@ class StageTrace:
10
  token_out: Optional[int] = None
11
  cost_usd: Optional[float] = None
12
 
 
13
  @dataclass(frozen=True)
14
  class StageResult:
15
  ok: bool
 
1
  from dataclasses import dataclass
2
  from typing import Any, Dict, Optional, List
3
 
4
+
5
  @dataclass(frozen=True)
6
  class StageTrace:
7
  stage: str
 
11
  token_out: Optional[int] = None
12
  cost_usd: Optional[float] = None
13
 
14
+
15
  @dataclass(frozen=True)
16
  class StageResult:
17
  ok: bool
nl2sql/verifier.py CHANGED
@@ -2,15 +2,20 @@ import sqlglot
2
  from sqlglot import expressions as exp
3
  from nl2sql.types import StageResult, StageTrace
4
 
 
5
  class Verifier:
6
  name = "verifier"
7
 
8
  def run(self, sql: str, exec_result: StageResult) -> StageResult:
9
  if not exec_result.ok:
10
- return StageResult(ok=False, data=None,
11
- trace=StageTrace(stage=self.name, duration_ms=0,
12
- notes={"reason": "execution_error"}),
13
- error=exec_result.errors)
 
 
 
 
14
 
15
  # Rule 1: check SELECT / GROUP consistency
16
  issues = []
@@ -25,9 +30,16 @@ class Verifier:
25
  issues.append(f"Parse error during verification: {e}")
26
 
27
  if issues:
28
- return StageResult(ok=False, data=None,
29
- trace=StageTrace(stage=self.name, duration_ms=0,
30
- notes={"issues": issues}),
31
- error=issues)
32
- return StageResult(ok=True, data={"verified": True},
33
- trace=StageTrace(stage=self.name, duration_ms=0))
 
 
 
 
 
 
 
 
2
  from sqlglot import expressions as exp
3
  from nl2sql.types import StageResult, StageTrace
4
 
5
+
6
  class Verifier:
7
  name = "verifier"
8
 
9
  def run(self, sql: str, exec_result: StageResult) -> StageResult:
10
  if not exec_result.ok:
11
+ return StageResult(
12
+ ok=False,
13
+ data=None,
14
+ trace=StageTrace(
15
+ stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
16
+ ),
17
+ error=exec_result.errors,
18
+ )
19
 
20
  # Rule 1: check SELECT / GROUP consistency
21
  issues = []
 
30
  issues.append(f"Parse error during verification: {e}")
31
 
32
  if issues:
33
+ return StageResult(
34
+ ok=False,
35
+ data=None,
36
+ trace=StageTrace(
37
+ stage=self.name, duration_ms=0, notes={"issues": issues}
38
+ ),
39
+ error=issues,
40
+ )
41
+ return StageResult(
42
+ ok=True,
43
+ data={"verified": True},
44
+ trace=StageTrace(stage=self.name, duration_ms=0),
45
+ )
tests/conftest.py CHANGED
@@ -4,4 +4,4 @@ from dotenv import load_dotenv
4
  ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5
  ENV_PATH = os.path.join(ROOT_DIR, ".env")
6
 
7
- load_dotenv(dotenv_path=ENV_PATH)
 
4
  ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5
  ENV_PATH = os.path.join(ROOT_DIR, ".env")
6
 
7
+ load_dotenv(dotenv_path=ENV_PATH)
tests/test_ambiguity.py CHANGED
@@ -2,18 +2,23 @@ from nl2sql.ambiguity_detector import AmbiguityDetector
2
  from nl2sql.types import StageResult
3
  from app.routers import nl2sql
4
 
 
5
  def test_detects_ambiguous_terms():
6
  det = AmbiguityDetector()
7
  res = det.detect("Show me recent top singers", "table: singer(id,name,age)")
8
  assert len(res) >= 1
9
  assert "recent" in res[0].lower()
10
 
 
11
  def test_not_false_positive():
12
  det = AmbiguityDetector()
13
  res = det.detect("List all singers older than 30", "table: singer(id, name, age)")
14
  assert res == []
15
 
 
16
  def test_ambiguity_response():
17
- fake_result = StageResult(ok=True, data={"ambiguous": True, "questions": ["Clarify column?"]})
 
 
18
  response = nl2sql._to_dict(fake_result.data)
19
- assert response["ambiguous"] is True
 
2
  from nl2sql.types import StageResult
3
  from app.routers import nl2sql
4
 
5
+
6
  def test_detects_ambiguous_terms():
7
  det = AmbiguityDetector()
8
  res = det.detect("Show me recent top singers", "table: singer(id,name,age)")
9
  assert len(res) >= 1
10
  assert "recent" in res[0].lower()
11
 
12
+
13
  def test_not_false_positive():
14
  det = AmbiguityDetector()
15
  res = det.detect("List all singers older than 30", "table: singer(id, name, age)")
16
  assert res == []
17
 
18
+
19
  def test_ambiguity_response():
20
+ fake_result = StageResult(
21
+ ok=True, data={"ambiguous": True, "questions": ["Clarify column?"]}
22
+ )
23
  response = nl2sql._to_dict(fake_result.data)
24
+ assert response["ambiguous"] is True
tests/test_executor.py CHANGED
@@ -1,9 +1,11 @@
1
  from nl2sql.executor import Executor
2
  from adapters.db.sqlite_adapter import SQLiteAdapter
3
 
 
4
  def test_executor_runs_select(tmp_path):
5
  db_path = tmp_path / "test.db"
6
  import sqlite3
 
7
  conn = sqlite3.connect(db_path)
8
  conn.execute("CREATE TABLE users(id INT, name TEXT);")
9
  conn.execute("INSERT INTO users VALUES (1, 'Alice');")
 
1
  from nl2sql.executor import Executor
2
  from adapters.db.sqlite_adapter import SQLiteAdapter
3
 
4
+
5
  def test_executor_runs_select(tmp_path):
6
  db_path = tmp_path / "test.db"
7
  import sqlite3
8
+
9
  conn = sqlite3.connect(db_path)
10
  conn.execute("CREATE TABLE users(id INT, name TEXT);")
11
  conn.execute("INSERT INTO users VALUES (1, 'Alice');")
tests/test_generator.py CHANGED
@@ -5,6 +5,7 @@ from nl2sql.types import StageResult
5
 
6
  # --- Dummy LLMs (respect the 5-tuple contract) --------------------------------
7
 
 
8
  class LLM_OK:
9
  def generate_sql(self, **kwargs):
10
  # contract: (sql, rationale, t_in, t_out, cost)
@@ -37,11 +38,12 @@ class LLM_CONTRACT_SHORT:
37
 
38
  # --- Parametrized negative cases ----------------------------------------------
39
 
 
40
  @pytest.mark.parametrize(
41
  "llm, err_keyword",
42
  [
43
- (LLM_EMPTY_SQL(), "empty"), # empty or non-string sql
44
- (LLM_NON_SELECT(), "non-select"), # generated non-SELECT
45
  (LLM_CONTRACT_NONE(), "contract violation"),
46
  (LLM_CONTRACT_SHORT(), "contract violation"),
47
  ],
@@ -52,7 +54,7 @@ def test_generator_errors_do_not_create_trace(llm, err_keyword):
52
  user_query="show all singers",
53
  schema_preview="CREATE TABLE singer(id int, name text);",
54
  plan_text="-- plan --",
55
- clarify_answers={}
56
  )
57
  assert isinstance(r, StageResult)
58
  assert r.ok is False
@@ -65,13 +67,14 @@ def test_generator_errors_do_not_create_trace(llm, err_keyword):
65
 
66
  # --- Positive case (success) ---------------------------------------------------
67
 
 
68
  def test_generator_success_has_valid_trace_and_data():
69
  gen = Generator(llm=LLM_OK())
70
  r = gen.run(
71
  user_query="show all singers",
72
  schema_preview="CREATE TABLE singer(id int, name text);",
73
  plan_text="-- plan --",
74
- clarify_answers={}
75
  )
76
 
77
  # Basic success checks
 
5
 
6
  # --- Dummy LLMs (respect the 5-tuple contract) --------------------------------
7
 
8
+
9
  class LLM_OK:
10
  def generate_sql(self, **kwargs):
11
  # contract: (sql, rationale, t_in, t_out, cost)
 
38
 
39
  # --- Parametrized negative cases ----------------------------------------------
40
 
41
+
42
  @pytest.mark.parametrize(
43
  "llm, err_keyword",
44
  [
45
+ (LLM_EMPTY_SQL(), "empty"), # empty or non-string sql
46
+ (LLM_NON_SELECT(), "non-select"), # generated non-SELECT
47
  (LLM_CONTRACT_NONE(), "contract violation"),
48
  (LLM_CONTRACT_SHORT(), "contract violation"),
49
  ],
 
54
  user_query="show all singers",
55
  schema_preview="CREATE TABLE singer(id int, name text);",
56
  plan_text="-- plan --",
57
+ clarify_answers={},
58
  )
59
  assert isinstance(r, StageResult)
60
  assert r.ok is False
 
67
 
68
  # --- Positive case (success) ---------------------------------------------------
69
 
70
+
71
  def test_generator_success_has_valid_trace_and_data():
72
  gen = Generator(llm=LLM_OK())
73
  r = gen.run(
74
  user_query="show all singers",
75
  schema_preview="CREATE TABLE singer(id int, name text);",
76
  plan_text="-- plan --",
77
+ clarify_answers={},
78
  )
79
 
80
  # Basic success checks
tests/test_nl2sql_router.py CHANGED
@@ -9,8 +9,10 @@ client = TestClient(app)
9
  def fake_trace(stage: str):
10
  return StageTrace(stage=stage, duration_ms=10.0)
11
 
 
12
  path = app.url_path_for("nl2sql_handler")
13
 
 
14
  # --- 1) Clarify / ambiguity case ---------------------------------------------
15
  def test_ambiguity_route(monkeypatch):
16
  from app.routers import nl2sql
@@ -47,7 +49,9 @@ def test_error_route(monkeypatch):
47
  from app.routers import nl2sql
48
 
49
  def fake_run(*args, **kwargs):
50
- return StageResult(ok=False, error=["Bad SQL"], data={"traces": [fake_trace("safety")]})
 
 
51
 
52
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
53
 
 
9
  def fake_trace(stage: str):
10
  return StageTrace(stage=stage, duration_ms=10.0)
11
 
12
+
13
  path = app.url_path_for("nl2sql_handler")
14
 
15
+
16
  # --- 1) Clarify / ambiguity case ---------------------------------------------
17
  def test_ambiguity_route(monkeypatch):
18
  from app.routers import nl2sql
 
49
  from app.routers import nl2sql
50
 
51
  def fake_run(*args, **kwargs):
52
+ return StageResult(
53
+ ok=False, error=["Bad SQL"], data={"traces": [fake_trace("safety")]}
54
+ )
55
 
56
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
57
 
tests/test_openai_provider.py CHANGED
@@ -6,21 +6,23 @@ from adapters.llm.openai_provider import OpenAIProvider
6
  # Helper class to fake the completion object returned by OpenAI SDK
7
  class FakeCompletion:
8
  def __init__(self, content: str, prompt_tokens=5, completion_tokens=7):
9
- self.choices = [type("Choice", (), {"message": type("Msg", (), {"content": content})})]
10
- self.usage = type("Usage", (), {
11
- "prompt_tokens": prompt_tokens,
12
- "completion_tokens": completion_tokens
13
- })
 
 
 
14
 
15
 
16
  # --- Case 1: clean valid JSON --------------------------------------------------
17
  def test_generate_sql_valid_json(monkeypatch):
18
  provider = OpenAIProvider()
19
 
20
- fake_content = json.dumps({
21
- "sql": "SELECT * FROM singer;",
22
- "rationale": "List all singers."
23
- })
24
  fake_completion = FakeCompletion(fake_content)
25
 
26
  # Monkeypatch client.chat.completions.create
@@ -33,7 +35,7 @@ def test_generate_sql_valid_json(monkeypatch):
33
  user_query="show all singers",
34
  schema_preview="CREATE TABLE singer(id int, name text);",
35
  plan_text="-- plan --",
36
- clarify_answers={}
37
  )
38
 
39
  assert sql.strip().lower().startswith("select")
@@ -48,7 +50,7 @@ def test_generate_sql_recover_from_partial_json(monkeypatch):
48
  provider = OpenAIProvider()
49
 
50
  # invalid JSON with text around it
51
- fake_content = "Here is the result:\n{ \"sql\": \"SELECT * FROM users;\", \"rationale\": \"list users\" }\nThanks!"
52
  fake_completion = FakeCompletion(fake_content)
53
 
54
  def fake_create(*args, **kwargs):
@@ -59,7 +61,7 @@ def test_generate_sql_recover_from_partial_json(monkeypatch):
59
  sql, rationale, *_ = provider.generate_sql(
60
  user_query="show all users",
61
  schema_preview="CREATE TABLE users(id int, name text);",
62
- plan_text="-- plan --"
63
  )
64
 
65
  assert sql.lower().startswith("select")
@@ -83,5 +85,5 @@ def test_generate_sql_invalid_json(monkeypatch):
83
  provider.generate_sql(
84
  user_query="show X",
85
  schema_preview="CREATE TABLE t(id int);",
86
- plan_text="-- plan --"
87
  )
 
6
  # Helper class to fake the completion object returned by OpenAI SDK
7
  class FakeCompletion:
8
  def __init__(self, content: str, prompt_tokens=5, completion_tokens=7):
9
+ self.choices = [
10
+ type("Choice", (), {"message": type("Msg", (), {"content": content})})
11
+ ]
12
+ self.usage = type(
13
+ "Usage",
14
+ (),
15
+ {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens},
16
+ )
17
 
18
 
19
  # --- Case 1: clean valid JSON --------------------------------------------------
20
  def test_generate_sql_valid_json(monkeypatch):
21
  provider = OpenAIProvider()
22
 
23
+ fake_content = json.dumps(
24
+ {"sql": "SELECT * FROM singer;", "rationale": "List all singers."}
25
+ )
 
26
  fake_completion = FakeCompletion(fake_content)
27
 
28
  # Monkeypatch client.chat.completions.create
 
35
  user_query="show all singers",
36
  schema_preview="CREATE TABLE singer(id int, name text);",
37
  plan_text="-- plan --",
38
+ clarify_answers={},
39
  )
40
 
41
  assert sql.strip().lower().startswith("select")
 
50
  provider = OpenAIProvider()
51
 
52
  # invalid JSON with text around it
53
+ fake_content = 'Here is the result:\n{ "sql": "SELECT * FROM users;", "rationale": "list users" }\nThanks!'
54
  fake_completion = FakeCompletion(fake_content)
55
 
56
  def fake_create(*args, **kwargs):
 
61
  sql, rationale, *_ = provider.generate_sql(
62
  user_query="show all users",
63
  schema_preview="CREATE TABLE users(id int, name text);",
64
+ plan_text="-- plan --",
65
  )
66
 
67
  assert sql.lower().startswith("select")
 
85
  provider.generate_sql(
86
  user_query="show X",
87
  schema_preview="CREATE TABLE t(id int);",
88
+ plan_text="-- plan --",
89
  )
tests/test_pipeline_integration.py CHANGED
@@ -5,8 +5,10 @@ from nl2sql.types import StageResult, StageTrace
5
 
6
  # --- Dummy stages to isolate pipeline -----------------------------------------
7
 
 
8
  class DummyDetector:
9
  """Simulates ambiguity detector stage."""
 
10
  def __init__(self, ambiguous=False):
11
  self.ambiguous = ambiguous
12
 
@@ -17,6 +19,7 @@ class DummyDetector:
17
 
18
  class DummyPlanner:
19
  """Simulates planner stage."""
 
20
  def run(self, *, user_query, schema_preview):
21
  trace = StageTrace(stage="planner", duration_ms=1.0)
22
  if "fail_plan" in user_query:
@@ -26,17 +29,21 @@ class DummyPlanner:
26
 
27
  class DummyGenerator:
28
  """Simulates generator stage."""
 
29
  def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
30
  trace = StageTrace(stage="generator", duration_ms=1.0)
31
  if "fail_gen" in user_query:
32
  return StageResult(ok=False, error=["Generator failed"], trace=trace)
33
  sql = "SELECT * FROM singer;"
34
  rationale = "List all singers."
35
- return StageResult(ok=True, data={"sql": sql, "rationale": rationale}, trace=trace)
 
 
36
 
37
 
38
  class DummySafety:
39
  """Simulates safety stage."""
 
40
  def check(self, sql):
41
  trace = StageTrace(stage="safety", duration_ms=1.0)
42
  if "DROP" in sql.upper():
@@ -50,12 +57,12 @@ def test_pipeline_success():
50
  detector=DummyDetector(ambiguous=False),
51
  planner=DummyPlanner(),
52
  generator=DummyGenerator(),
53
- safety=DummySafety()
54
  )
55
 
56
  r = pipeline.run(
57
  user_query="show all singers",
58
- schema_preview="CREATE TABLE singer(id int, name text);"
59
  )
60
 
61
  assert isinstance(r, StageResult)
@@ -73,13 +80,10 @@ def test_pipeline_ambiguity():
73
  detector=DummyDetector(ambiguous=True),
74
  planner=DummyPlanner(),
75
  generator=DummyGenerator(),
76
- safety=DummySafety()
77
  )
78
 
79
- r = pipeline.run(
80
- user_query="show data",
81
- schema_preview="CREATE TABLE x(id int);"
82
- )
83
 
84
  assert isinstance(r, StageResult)
85
  assert r.ok is True
@@ -93,11 +97,10 @@ def test_pipeline_plan_fail():
93
  detector=DummyDetector(),
94
  planner=DummyPlanner(),
95
  generator=DummyGenerator(),
96
- safety=DummySafety()
97
  )
98
  r = pipeline.run(
99
- user_query="fail_plan",
100
- schema_preview="CREATE TABLE singer(id int);"
101
  )
102
  assert isinstance(r, StageResult)
103
  assert r.ok is False
@@ -110,11 +113,10 @@ def test_pipeline_gen_fail():
110
  detector=DummyDetector(),
111
  planner=DummyPlanner(),
112
  generator=DummyGenerator(),
113
- safety=DummySafety()
114
  )
115
  r = pipeline.run(
116
- user_query="fail_gen",
117
- schema_preview="CREATE TABLE singer(id int);"
118
  )
119
  assert r.ok is False
120
  assert "Generator failed" in " ".join(r.error or [])
@@ -126,17 +128,18 @@ def test_pipeline_safety_fail():
126
  def run(self, **kw):
127
  trace = StageTrace(stage="generator", duration_ms=1.0)
128
  # Generate a DROP TABLE → unsafe
129
- return StageResult(ok=True, data={"sql": "DROP TABLE x;", "rationale": "oops"}, trace=trace)
 
 
130
 
131
  pipeline = Pipeline(
132
  detector=DummyDetector(),
133
  planner=DummyPlanner(),
134
  generator=UnsafeGen(),
135
- safety=DummySafety()
136
  )
137
  r = pipeline.run(
138
- user_query="drop something",
139
- schema_preview="CREATE TABLE x(id int);"
140
  )
141
  assert r.ok is False
142
  assert "unsafe" in " ".join(r.error or []).lower()
 
5
 
6
  # --- Dummy stages to isolate pipeline -----------------------------------------
7
 
8
+
9
  class DummyDetector:
10
  """Simulates ambiguity detector stage."""
11
+
12
  def __init__(self, ambiguous=False):
13
  self.ambiguous = ambiguous
14
 
 
19
 
20
  class DummyPlanner:
21
  """Simulates planner stage."""
22
+
23
  def run(self, *, user_query, schema_preview):
24
  trace = StageTrace(stage="planner", duration_ms=1.0)
25
  if "fail_plan" in user_query:
 
29
 
30
  class DummyGenerator:
31
  """Simulates generator stage."""
32
+
33
  def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
34
  trace = StageTrace(stage="generator", duration_ms=1.0)
35
  if "fail_gen" in user_query:
36
  return StageResult(ok=False, error=["Generator failed"], trace=trace)
37
  sql = "SELECT * FROM singer;"
38
  rationale = "List all singers."
39
+ return StageResult(
40
+ ok=True, data={"sql": sql, "rationale": rationale}, trace=trace
41
+ )
42
 
43
 
44
  class DummySafety:
45
  """Simulates safety stage."""
46
+
47
  def check(self, sql):
48
  trace = StageTrace(stage="safety", duration_ms=1.0)
49
  if "DROP" in sql.upper():
 
57
  detector=DummyDetector(ambiguous=False),
58
  planner=DummyPlanner(),
59
  generator=DummyGenerator(),
60
+ safety=DummySafety(),
61
  )
62
 
63
  r = pipeline.run(
64
  user_query="show all singers",
65
+ schema_preview="CREATE TABLE singer(id int, name text);",
66
  )
67
 
68
  assert isinstance(r, StageResult)
 
80
  detector=DummyDetector(ambiguous=True),
81
  planner=DummyPlanner(),
82
  generator=DummyGenerator(),
83
+ safety=DummySafety(),
84
  )
85
 
86
+ r = pipeline.run(user_query="show data", schema_preview="CREATE TABLE x(id int);")
 
 
 
87
 
88
  assert isinstance(r, StageResult)
89
  assert r.ok is True
 
97
  detector=DummyDetector(),
98
  planner=DummyPlanner(),
99
  generator=DummyGenerator(),
100
+ safety=DummySafety(),
101
  )
102
  r = pipeline.run(
103
+ user_query="fail_plan", schema_preview="CREATE TABLE singer(id int);"
 
104
  )
105
  assert isinstance(r, StageResult)
106
  assert r.ok is False
 
113
  detector=DummyDetector(),
114
  planner=DummyPlanner(),
115
  generator=DummyGenerator(),
116
+ safety=DummySafety(),
117
  )
118
  r = pipeline.run(
119
+ user_query="fail_gen", schema_preview="CREATE TABLE singer(id int);"
 
120
  )
121
  assert r.ok is False
122
  assert "Generator failed" in " ".join(r.error or [])
 
128
  def run(self, **kw):
129
  trace = StageTrace(stage="generator", duration_ms=1.0)
130
  # Generate a DROP TABLE → unsafe
131
+ return StageResult(
132
+ ok=True, data={"sql": "DROP TABLE x;", "rationale": "oops"}, trace=trace
133
+ )
134
 
135
  pipeline = Pipeline(
136
  detector=DummyDetector(),
137
  planner=DummyPlanner(),
138
  generator=UnsafeGen(),
139
+ safety=DummySafety(),
140
  )
141
  r = pipeline.run(
142
+ user_query="drop something", schema_preview="CREATE TABLE x(id int);"
 
143
  )
144
  assert r.ok is False
145
  assert "unsafe" in " ".join(r.error or []).lower()
tests/test_safety.py CHANGED
@@ -2,7 +2,6 @@ from nl2sql.safety import Safety
2
  import pytest
3
 
4
 
5
-
6
  def test_safety_allows_select():
7
  s = Safety()
8
  result = s.check("SELECT * FROM users;")
@@ -10,6 +9,7 @@ def test_safety_allows_select():
10
  assert "sql" in result.data
11
  assert result.trace.stage == "safety"
12
 
 
13
  def test_safety_allows_with_select_cte():
14
  s = Safety()
15
  sql = """
@@ -21,12 +21,14 @@ def test_safety_allows_with_select_cte():
21
  r = s.check(sql)
22
  assert r.ok
23
 
 
24
  def test_safety_allows_select_with_comments_and_newlines():
25
  s = Safety()
26
  sql = "/* head */ \n -- inline\n SELECT 1; -- tail"
27
  r = s.check(sql)
28
  assert r.ok
29
 
 
30
  def test_safety_allows_keywords_inside_string_literals():
31
  s = Safety()
32
  sql = "SELECT 'DROP TABLE x' as note, 'delete from y' as text;"
@@ -40,32 +42,39 @@ def test_safety_blocks_delete():
40
  assert not result.ok
41
  assert any("Forbidden" in e or "Non-SELECT" in e for e in (result.error or []))
42
 
43
- @pytest.mark.parametrize("sql", [
44
- "UPDATE users SET name='X' WHERE id=1;",
45
- "INSERT INTO users(id) VALUES (1);",
46
- "DROP TABLE users;",
47
- "CREATE TABLE x(id INT);",
48
- "ALTER TABLE users ADD COLUMN x INT;",
49
- "ATTACH DATABASE 'hack.db' AS h;",
50
- "PRAGMA journal_mode=WAL;",
51
- ])
 
 
 
 
52
  def test_safety_blocks_forbidden_statements(sql):
53
  s = Safety()
54
  res = s.check(sql)
55
  assert not res.ok
56
 
 
57
  def test_safety_blocks_stacked_delete_after_select():
58
  s = Safety()
59
  sql = "SELECT * FROM users; DELETE FROM users;"
60
  r = s.check(sql)
61
  assert not r.ok
62
 
 
63
  def test_safety_blocks_stacked_delete_with_spaces():
64
  s = Safety()
65
  sql = "SELECT * FROM users ; \n DELETE users;"
66
  r = s.check(sql)
67
  assert not r.ok
68
 
 
69
  def test_safety_blocks_delete_inside_cte():
70
  s = Safety()
71
  sql = """
@@ -75,26 +84,35 @@ def test_safety_blocks_delete_inside_cte():
75
  r = s.check(sql)
76
  assert not r.ok
77
 
78
- @pytest.mark.parametrize("sql", [
79
- "/*D*/ROP TABLE users;",
80
- "PR/*x*/AGMA journal_mode=WAL;",
81
- "AL/* comment */TER TABLE x ADD COLUMN y INT;",
82
- ])
 
 
 
 
83
  def test_safety_blocks_comment_obfuscation(sql):
84
  s = Safety()
85
  r = s.check(sql)
86
  assert not r.ok
87
 
88
- @pytest.mark.parametrize("sql", [
89
- "pragma journal_mode=WAL;", # lower-case
90
- " PRAGMA user_version = 5 ; ",
91
- "\nATTACH DATABASE 'hack.db' AS h;",
92
- ])
 
 
 
 
93
  def test_safety_blocks_forbidden_case_and_spacing(sql):
94
  s = Safety()
95
  r = s.check(sql)
96
  assert not r.ok
97
 
 
98
  def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
99
  s = Safety()
100
  sql = "SELECT 1; -- now do something bad\n"
 
2
  import pytest
3
 
4
 
 
5
  def test_safety_allows_select():
6
  s = Safety()
7
  result = s.check("SELECT * FROM users;")
 
9
  assert "sql" in result.data
10
  assert result.trace.stage == "safety"
11
 
12
+
13
  def test_safety_allows_with_select_cte():
14
  s = Safety()
15
  sql = """
 
21
  r = s.check(sql)
22
  assert r.ok
23
 
24
+
25
  def test_safety_allows_select_with_comments_and_newlines():
26
  s = Safety()
27
  sql = "/* head */ \n -- inline\n SELECT 1; -- tail"
28
  r = s.check(sql)
29
  assert r.ok
30
 
31
+
32
  def test_safety_allows_keywords_inside_string_literals():
33
  s = Safety()
34
  sql = "SELECT 'DROP TABLE x' as note, 'delete from y' as text;"
 
42
  assert not result.ok
43
  assert any("Forbidden" in e or "Non-SELECT" in e for e in (result.error or []))
44
 
45
+
46
+ @pytest.mark.parametrize(
47
+ "sql",
48
+ [
49
+ "UPDATE users SET name='X' WHERE id=1;",
50
+ "INSERT INTO users(id) VALUES (1);",
51
+ "DROP TABLE users;",
52
+ "CREATE TABLE x(id INT);",
53
+ "ALTER TABLE users ADD COLUMN x INT;",
54
+ "ATTACH DATABASE 'hack.db' AS h;",
55
+ "PRAGMA journal_mode=WAL;",
56
+ ],
57
+ )
58
  def test_safety_blocks_forbidden_statements(sql):
59
  s = Safety()
60
  res = s.check(sql)
61
  assert not res.ok
62
 
63
+
64
  def test_safety_blocks_stacked_delete_after_select():
65
  s = Safety()
66
  sql = "SELECT * FROM users; DELETE FROM users;"
67
  r = s.check(sql)
68
  assert not r.ok
69
 
70
+
71
  def test_safety_blocks_stacked_delete_with_spaces():
72
  s = Safety()
73
  sql = "SELECT * FROM users ; \n DELETE users;"
74
  r = s.check(sql)
75
  assert not r.ok
76
 
77
+
78
  def test_safety_blocks_delete_inside_cte():
79
  s = Safety()
80
  sql = """
 
84
  r = s.check(sql)
85
  assert not r.ok
86
 
87
+
88
+ @pytest.mark.parametrize(
89
+ "sql",
90
+ [
91
+ "/*D*/ROP TABLE users;",
92
+ "PR/*x*/AGMA journal_mode=WAL;",
93
+ "AL/* comment */TER TABLE x ADD COLUMN y INT;",
94
+ ],
95
+ )
96
  def test_safety_blocks_comment_obfuscation(sql):
97
  s = Safety()
98
  r = s.check(sql)
99
  assert not r.ok
100
 
101
+
102
+ @pytest.mark.parametrize(
103
+ "sql",
104
+ [
105
+ "pragma journal_mode=WAL;", # lower-case
106
+ " PRAGMA user_version = 5 ; ",
107
+ "\nATTACH DATABASE 'hack.db' AS h;",
108
+ ],
109
+ )
110
  def test_safety_blocks_forbidden_case_and_spacing(sql):
111
  s = Safety()
112
  r = s.check(sql)
113
  assert not r.ok
114
 
115
+
116
  def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
117
  s = Safety()
118
  sql = "SELECT 1; -- now do something bad\n"
tests/test_stage_types.py CHANGED
@@ -1,16 +1,19 @@
1
  from nl2sql.types import StageResult, StageTrace
2
 
 
3
  def test_error_response():
4
  r = StageResult(ok=False, error=["Syntax error"])
5
  assert not r.ok
6
  assert r.error == ["Syntax error"]
7
 
 
8
  def test_trace_dataclass_structure():
9
  t = StageTrace(stage="planner", duration_ms=12.5, token_in=10, token_out=20)
10
  assert t.stage == "planner"
11
  assert isinstance(t.duration_ms, float)
12
  assert t.token_out == 20
13
 
 
14
  def test_stage_result_defaults():
15
  r = StageResult(ok=True)
16
  assert r.ok
 
1
  from nl2sql.types import StageResult, StageTrace
2
 
3
+
4
  def test_error_response():
5
  r = StageResult(ok=False, error=["Syntax error"])
6
  assert not r.ok
7
  assert r.error == ["Syntax error"]
8
 
9
+
10
  def test_trace_dataclass_structure():
11
  t = StageTrace(stage="planner", duration_ms=12.5, token_in=10, token_out=20)
12
  assert t.stage == "planner"
13
  assert isinstance(t.duration_ms, float)
14
  assert t.token_out == 20
15
 
16
+
17
  def test_stage_result_defaults():
18
  r = StageResult(ok=True)
19
  assert r.ok
ui/benchmark_app.py CHANGED
@@ -22,8 +22,8 @@ df = pd.DataFrame(rows)
22
  st.subheader("Aggregate Metrics")
23
  col1, col2, col3, col4 = st.columns(4)
24
  col1.metric("Total Queries", len(df))
25
- col2.metric("Execution Accuracy", f"{df['exec_acc'].mean()*100:.1f}%")
26
- col3.metric("Safety Violations", f"{df['safe_fail'].mean()*100:.1f}%")
27
  col4.metric("Average Latency (ms)", f"{df['latency_ms'].mean():.0f}")
28
 
29
  # 3. Latency Distribution
@@ -33,13 +33,23 @@ st.plotly_chart(fig1, use_container_width=True)
33
 
34
  # 4. Cost vs Accuracy
35
  st.subheader("Cost vs Execution Accuracy")
36
- fig2 = px.scatter(df, x="cost_usd", y="exec_acc", color="provider",
37
- title="Trade-off: Cost vs Accuracy", hover_data=["query"])
 
 
 
 
 
 
38
  st.plotly_chart(fig2, use_container_width=True)
39
 
40
  # 5. Repair Stats
41
  if "repair_attempts" in df.columns:
42
  st.subheader("Repair Attempts")
43
- fig3 = px.bar(df.groupby("repair_attempts").size().reset_index(name="count"),
44
- x="repair_attempts", y="count", title="Number of Repair Attempts per Query")
 
 
 
 
45
  st.plotly_chart(fig3, use_container_width=True)
 
22
  st.subheader("Aggregate Metrics")
23
  col1, col2, col3, col4 = st.columns(4)
24
  col1.metric("Total Queries", len(df))
25
+ col2.metric("Execution Accuracy", f"{df['exec_acc'].mean() * 100:.1f}%")
26
+ col3.metric("Safety Violations", f"{df['safe_fail'].mean() * 100:.1f}%")
27
  col4.metric("Average Latency (ms)", f"{df['latency_ms'].mean():.0f}")
28
 
29
  # 3. Latency Distribution
 
33
 
34
  # 4. Cost vs Accuracy
35
  st.subheader("Cost vs Execution Accuracy")
36
+ fig2 = px.scatter(
37
+ df,
38
+ x="cost_usd",
39
+ y="exec_acc",
40
+ color="provider",
41
+ title="Trade-off: Cost vs Accuracy",
42
+ hover_data=["query"],
43
+ )
44
  st.plotly_chart(fig2, use_container_width=True)
45
 
46
  # 5. Repair Stats
47
  if "repair_attempts" in df.columns:
48
  st.subheader("Repair Attempts")
49
+ fig3 = px.bar(
50
+ df.groupby("repair_attempts").size().reset_index(name="count"),
51
+ x="repair_attempts",
52
+ y="count",
53
+ title="Number of Repair Attempts per Query",
54
+ )
55
  st.plotly_chart(fig3, use_container_width=True)