Melika Kheirieh committed on
Commit
570f7bd
·
1 Parent(s): 5eeca35

init: NL2SQL Copilot base with API and Dockerfile

Browse files
.github/workflows/ci.yml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ branches: [ main, develop ]
6
+ pull_request:
7
+
8
+ jobs:
9
+ build-test:
10
+ runs-on: ubuntu-latest
11
+
12
+ env:
13
+ PIP_NO_CACHE_DIR: 1
14
+
15
+ steps:
16
+ - name: Checkout repository
17
+ uses: actions/checkout@v4
18
+
19
+ - name: Set up Python
20
+ uses: actions/setup-python@v5
21
+ with:
22
+ python-version: "3.12"
23
+
24
+ - name: Install dependencies
25
+ run: |
26
+ python -m pip install --upgrade pip
27
+ pip install -r requirements.txt
28
+
29
+ - name: Lint (ruff)
30
+ run: ruff check .
31
+
32
+ - name: Type check (mypy)
33
+ run: mypy .
34
+
35
+ - name: Run tests
36
+ run: pytest -q
37
+
38
+ docker-build:
39
+ needs: build-test
40
+ runs-on: ubuntu-latest
41
+ if: github.ref == 'refs/heads/main'
42
+
43
+ steps:
44
+ - name: Checkout code
45
+ uses: actions/checkout@v4
46
+
47
+ - name: Login to GHCR
48
+ if: secrets.GHCR_TOKEN != ''
49
+ run: echo "${{ secrets.GHCR_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin
50
+
51
+ - name: Build Docker image
52
+ run: |
53
+ IMAGE=ghcr.io/${{ github.repository_owner }}/nl2sql-copilot:${{ github.sha }}
54
+ docker build -t $IMAGE .
55
+ echo "IMAGE=$IMAGE" >> $GITHUB_ENV
56
+
57
+ - name: Push image
58
+ if: secrets.GHCR_TOKEN != ''
59
+ run: docker push $IMAGE
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---------- Stage 1: Build wheels ----------
2
+ FROM python:3.12-slim AS builder
3
+
4
+ # Set working directory for the build stage
5
+ WORKDIR /build
6
+
7
+ # Install system dependencies required to compile some Python packages
8
+ RUN apt-get update && apt-get install -y --no-install-recommends \
9
+ build-essential libpq-dev && \
10
+ rm -rf /var/lib/apt/lists/*
11
+
12
+ # Copy only requirements first (so Docker caching works efficiently)
13
+ COPY requirements.txt .
14
+
15
+ # Build all dependencies as wheel files inside /wheels
16
+ RUN pip install --upgrade pip && \
17
+ pip wheel --wheel-dir /wheels -r requirements.txt
18
+
19
+
20
+ # ---------- Stage 2: Runtime image ----------
21
+ FROM python:3.12-slim AS runtime
22
+
23
+ # Set working directory for the application
24
+ WORKDIR /app
25
+
26
+ # Copy prebuilt wheels from the builder stage
27
+ COPY --from=builder /wheels /wheels
28
+
29
+ # Install dependencies from prebuilt wheels (no need to compile again)
30
+ COPY requirements.txt .
31
+ RUN pip install --no-cache-dir --find-links=/wheels -r requirements.txt
32
+
33
+ # Copy the actual application code
34
+ COPY . .
35
+
36
+ # Expose the FastAPI port
37
+ EXPOSE 8000
38
+
39
+ # Start FastAPI with Uvicorn
40
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers"]
adapters/db/base.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Dict, Any, Protocol
2
+
3
+
4
+ class DBAdapter(Protocol):
5
+ """Abstract database adapter for read-only queries."""
6
+ name: str
7
+ dialect: str
8
+
9
+ def preview_schema(self, limit_per_table: int = 0) -> str:
10
+ """Generate a readable summary of the database schema with optional sample rows per table."""
11
+
12
+ def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
13
+ """Execute a SELECT query and return (rows, columns)."""
adapters/db/postgres_adapter.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import psycopg
2
+ from typing import Any, List, Tuple
3
+ from adapters.db.base import DBAdapter
4
+
5
+ class PostgresAdapter(DBAdapter):
6
+ name = "postgres"
7
+ dialect = "postgres"
8
+
9
+ def __init__(self, dsn: str):
10
+ """
11
+ DSN example:
12
+ "dbname=demo user=postgres password=postgres host=localhost port=5432"
13
+ """
14
+ self.dsn = dsn
15
+
16
+ def preview_schema(self, limit_per_table: int = 0) -> str:
17
+ with psycopg.connect(self.dsn) as conn:
18
+ cur = conn.cursor()
19
+ cur.execute("""
20
+ SELECT table_name
21
+ FROM information_schema.tables
22
+ WHERE table_schema = 'public';
23
+ """)
24
+ tables = [t[0] for t in cur.fetchall()]
25
+ lines = []
26
+ for t in tables:
27
+ cur.execute("""
28
+ SELECT column_name, data_type
29
+ FROM information_schema.columns
30
+ WHERE table_name = %s;
31
+ """, (t,))
32
+ cols = [f"{c[0]}:{c[1]}" for c in cur.fetchall()]
33
+ lines.append(f"- {t} ({', '.join(cols)})")
34
+ return "\n".join(lines)
35
+
36
+ def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
37
+ if not sql.strip().lower().startswith("select"):
38
+ raise ValueError("Only SELECT statements are allowed.")
39
+ with psycopg.connect(self.dsn) as conn:
40
+ cur = conn.cursor()
41
+ cur.execute(sql)
42
+ rows = cur.fetchall()
43
+ cols = [desc[0] for desc in cur.description]
44
+ return rows, cols
adapters/db/sqlite_adapter.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from typing import List, Tuple, Any
3
+ from adapters.db.base import DBAdapter
4
+
5
+ class SQLiteAdapter(DBAdapter):
6
+ name = "sqlite"
7
+ dialect = "sqlite"
8
+
9
+ def __init__(self, path: str):
10
+ self.path = path
11
+
12
+ def preview_schema(self, limit_per_table: int = 0) -> str:
13
+ with sqlite3.connect(self.path, uri=True) as conn:
14
+ cur = conn.cursor()
15
+ cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
16
+ tables = [t[0] for t in cur.fetchall()]
17
+ lines = []
18
+ for t in tables:
19
+ cur.execute(f"PRAGMA table_info({t});")
20
+ cols = [f"{c[1]}:{c[2]}" for c in cur.fetchall()]
21
+ lines.append(f"- {t} ({', '.join(cols)})")
22
+ return "\n".join(lines)
23
+
24
+ def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
25
+ # enforce read-only connection
26
+ uri = f"file:{self.path}?mode=ro&uri=true"
27
+ with sqlite3.connect(uri, uri=True, timeout=3) as conn:
28
+ cur = conn.cursor()
29
+ cur.execute(sql)
30
+ rows = cur.fetchall()
31
+ cols = [desc[0] for desc in cur.description]
32
+ return rows, cols
adapters/llm/base.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adapters/llm/base.py
2
+ from __future__ import annotations
3
+ from typing import Tuple, List, Dict, Any, Protocol
4
+
5
+ class LLMProvider(Protocol):
6
+ provider_id: str
7
+
8
+ def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
9
+ """Return (plan_text, token_in, token_out, cost_usd)."""
10
+
11
+ def generate_sql(self, *, user_query: str, schema_preview: str, plan_text: str,
12
+ clarify_answers: Dict[str, Any] | None = None) -> Tuple[str, str, int, int, float]:
13
+ """Return (sql, rationale, token_in, token_out, cost_usd)."""
14
+
15
+ def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
16
+ """Return (patched_sql, token_in, token_out, cost_usd)."""
adapters/llm/openai_provider.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os
3
+ from typing import Tuple, Dict, Any, List
4
+ import json
5
+ from adapters.llm.base import LLMProvider
6
+ from openai import OpenAI
7
+
8
+ # NOTE: Read keys/base URL from env. Do NOT pass base_url in constructors.
9
+ # - OPENAI_API_KEY (required)
10
+ # - OPENAI_BASE_URL (optional; defaults to OpenAI public)
11
+ # - OPENAI_MODEL_ID (e.g., "gpt-4o-mini")
12
+
13
+
14
+
15
+ class OpenAIProvider(LLMProvider):
16
+ provider_id = "openai"
17
+
18
+ def __init__(self) -> None:
19
+ self.client = OpenAI(
20
+ api_key=os.environ["OPENAI_API_KEY"],
21
+ base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
22
+ )
23
+ self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")
24
+
25
+ def plan(self, *, user_query, schema_preview):
26
+ completion = self.client.chat.completions.create(
27
+ model=self.model,
28
+ messages=[
29
+ {"role": "system", "content": "You create SQL query plans."},
30
+ {"role": "user", "content": f"Query: {user_query}\nSchema:\n{schema_preview}"}
31
+ ],
32
+ temperature=0
33
+ )
34
+ msg = completion.choices[0].message.content
35
+ usage = completion.usage
36
+ return msg, usage.prompt_tokens, usage.completion_tokens, self._estimate_cost(usage)
37
+
38
+
39
+ def generate_sql(self, *, user_query, schema_preview, plan_text, clarify_answers=None):
40
+ prompt = f"""
41
+ You are a precise SQL generator.
42
+ Return ONLY valid JSON with two keys: "sql" and "rationale".
43
+ Do not include any markdown, backticks, or extra text.
44
+
45
+ Example:
46
+ {{
47
+ "sql": "SELECT * FROM singer;",
48
+ "rationale": "The user requested to list all singers."
49
+ }}
50
+
51
+ Now generate JSON for this input:
52
+
53
+ User query: {user_query}
54
+ Schema preview:
55
+ {schema_preview}
56
+ Plan: {plan_text}
57
+ Clarifications: {clarify_answers}
58
+ """
59
+ completion = self.client.chat.completions.create(
60
+ model=self.model,
61
+ messages=[
62
+ {"role": "system", "content": "You convert natural language to SQL."},
63
+ {"role": "user", "content": prompt}
64
+ ],
65
+ temperature=0
66
+ )
67
+ content = completion.choices[0].message.content.strip()
68
+ usage = completion.usage  # needed for token/cost accounting; may be None
69
+ t_in = usage.prompt_tokens if usage else None
70
+ t_out = usage.completion_tokens if usage else None
71
+ cost = self._estimate_cost(usage) if usage else None
72
+
73
+ # Robust JSON parse (with fallback to substring)
74
+ try:
75
+ parsed = json.loads(content)
76
+ except json.JSONDecodeError:
77
+ start = content.find("{")
78
+ end = content.rfind("}")
79
+ if start != -1 and end != -1:
80
+ try:
81
+ parsed = json.loads(content[start:end + 1])
82
+ except Exception:
83
+ raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
84
+ else:
85
+ raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
86
+
87
+ sql = (parsed.get("sql") or "").strip()
88
+ rationale = parsed.get("rationale") or ""
89
+
90
+ if not sql:
91
+ raise ValueError("LLM returned empty 'sql'")
92
+
93
+ # IMPORTANT: return the expected 5-tuple
94
+ return sql, rationale, t_in, t_out, cost
95
+
96
+
97
+ def repair(self, *, sql, error_msg, schema_preview):
98
+ completion = self.client.chat.completions.create(
99
+ model=self.model,
100
+ messages=[
101
+ {"role": "system", "content": "You fix SQL queries keeping them SELECT-only."},
102
+ {"role": "user", "content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}"}
103
+ ],
104
+ temperature=0
105
+ )
106
+ msg = completion.choices[0].message.content
107
+ usage = completion.usage
108
+ return msg, usage.prompt_tokens, usage.completion_tokens, self._estimate_cost(usage)
109
+
110
+ def _estimate_cost(self, usage):
111
+ # Rough estimation example — can be refined with official token pricing
112
+ total = usage.prompt_tokens + usage.completion_tokens
113
+ return total * 0.000001
app.py DELETED
@@ -1,235 +0,0 @@
1
- from config import (
2
- LLM_MODEL,
3
- LLM_TEMPERATURE,
4
- FORBIDDEN_KEYWORDS,
5
- FORBIDDEN_TABLES
6
- )
7
- import os
8
- import sqlite3
9
- import json
10
- import re
11
- from typing import Optional, Tuple, List
12
-
13
- import gradio as gr
14
- import sqlglot
15
- from sqlglot import exp
16
-
17
- from langchain_openai import ChatOpenAI
18
- from langchain_community.utilities import SQLDatabase
19
- from langchain.chains import create_sql_query_chain
20
- from langchain.prompts import ChatPromptTemplate
21
-
22
-
23
- def get_readonly_sqlite_url(db_path: str) -> str:
24
- return f"file:{db_path}?mode=ro&uri=true"
25
-
26
- def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str:
27
- uri = get_readonly_sqlite_url(db_path)
28
- with sqlite3.connect(uri, uri=True, timeout=3) as conn:
29
- conn.row_factory = sqlite3.Row
30
- cur = conn.cursor()
31
- cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
32
- tables = [r["name"] for r in cur.fetchall()]
33
- lines = []
34
- for t in tables:
35
- # skip SQLite internals
36
- if t in FORBIDDEN_TABLES:
37
- continue
38
- cur.execute(f"PRAGMA table_info({t});")
39
- cols = cur.fetchall()
40
- col_line = ", ".join([f"{c['name']}:{c['type']}" for c in cols])
41
- lines.append(f"- {t} ({col_line})")
42
- if limit_per_table > 0:
43
- try:
44
- cur.execute(f"SELECT * FROM {t} LIMIT {limit_per_table};")
45
- sample = cur.fetchall()
46
- if sample:
47
- lines.append(f" sample rows: {len(sample)}")
48
- except Exception:
49
- pass
50
- if not lines:
51
- return "(no user tables found)"
52
- return "\n".join(lines)
53
-
54
-
55
- def validate_sql_safe(sql: str) -> Tuple[bool, str]:
56
- if sql.count(";") > 0:
57
- if sql.strip().endswith(";"):
58
- if sql.strip()[:-1].count(";") > 0:
59
- return False, "Multiple statements are not allowed."
60
- else:
61
- return False, "Multiple statements are not allowed."
62
-
63
- upper = re.sub(r"\s+", " ", sql).strip()
64
- for kw in FORBIDDEN_KEYWORDS:
65
- if re.search(rf"\b{kw}\b", upper):
66
- return False, f"Keyword '{kw}' is not allowed."
67
-
68
- try:
69
- parsed = sqlglot.parse(sql, read='sqlite')
70
- except Exception as e:
71
- return False, f"SQL parse error: {e}"
72
-
73
- if not parsed or len(parsed) != 1:
74
- return False, "Exactly one SQL statement is allowed."
75
-
76
- stmt = parsed[0]
77
- if not isinstance(stmt, exp.Select):
78
- return False, "Only SELECT statements are allowed."
79
-
80
- for table in stmt.find_all(exp.Table):
81
- table_name = table.name.lower() if table.name else ""
82
- if table_name in FORBIDDEN_TABLES:
83
- return False, f"Access to {table_name} is not allowed."
84
-
85
- return True, "OK"
86
-
87
- def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[list[str], List[List]]:
88
- uri = get_readonly_sqlite_url(db_path)
89
- if not re.search(r"\bLIMIT\b", sql, re.IGNORECASE):
90
- sql = f"{sql.rstrip(';')} LIMIT {max_rows}"
91
-
92
- with sqlite3.connect(uri, uri=True, timeout=timeout) as conn:
93
- conn.row_factory = sqlite3.Row
94
- cur = conn.cursor()
95
- cur.execute(sql)
96
- rows = cur.fetchall()
97
- if rows:
98
- cols = rows[0].keys()
99
- data = [list(r) for r in rows]
100
- return list(cols), data
101
- else:
102
- return [], []
103
-
104
-
105
-
106
- custom_prompt = ChatPromptTemplate.from_template("""
107
- Given the following question, return ONLY a valid SQL query in JSON form.
108
-
109
- Question: {input}
110
- Database schema: {table_info}
111
-
112
- You may sample/preview at most {top_k} rows if you need examples.
113
-
114
- Respond in this exact JSON format:
115
- {{
116
- "sql": "<SQL_QUERY_HERE>"
117
- }}
118
- """)
119
-
120
-
121
- def make_sql_chain(sql_db: SQLDatabase):
122
- assert hasattr(sql_db, "get_table_info"), "Expected LangChain SQLDatabase"
123
- llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
124
- chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
125
- return chain
126
-
127
-
128
- def on_upload_database(db_file, state):
129
- if db_file is None:
130
- return state, "No file provided.", "(no schema)"
131
- path = db_file.name
132
-
133
- sql_db = SQLDatabase.from_uri(f"sqlite:///{path}")
134
-
135
- schema_text = get_schema_preview(path, limit_per_table=0)
136
-
137
- chain = make_sql_chain(sql_db)
138
-
139
- new_state = {
140
- "db_path": path,
141
- "sql_db": sql_db,
142
- "schema_text": schema_text,
143
- "chain": chain,
144
- }
145
- return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text
146
-
147
- def extract_sql_safe(output_text: str) -> str:
148
- try:
149
- obj = json.loads(output_text)
150
- if isinstance(obj, dict) and "sql" in obj:
151
- return obj["sql"].strip()
152
- except Exception:
153
- pass
154
- m = re.search(r"```sql\s*(.*?)\s*```", output_text, re.DOTALL | re.IGNORECASE)
155
- if m:
156
- return m.group(1).strip()
157
- return output_text.strip()
158
-
159
- def on_generate_query(question , max_rows, state):
160
- if not state or not state.get("db_path") or not state.get("chain"):
161
- return "Please upload a database first.", "", ""
162
- if not question or not question.strip():
163
- return "Please enter a question.", "", ""
164
-
165
- try:
166
- generated_sql = state["chain"].invoke({"question": question})
167
-
168
- sql = extract_sql_safe(str(generated_sql))
169
-
170
- ok, msg = validate_sql_safe(sql)
171
- if not ok:
172
- return f"Blocked SQL: {msg}", sql, ""
173
-
174
- cols, rows = execute_select(state["db_path"], sql, max_rows=max_rows)
175
- if not cols:
176
- return f"No rows returned.", sql, "[]"
177
-
178
- sample = [dict(zip(cols, r)) for r in rows[:50]]
179
- return f"Returned {len(rows)} row(s). Showing up to 50.", sql, json.dumps(sample, indent=2)
180
-
181
- except Exception as e:
182
- return f"Error: {e}", "", ""
183
-
184
-
185
- with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo:
186
- gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)")
187
- gr.Markdown(
188
- "Upload a **SQLite** file, ask a question in natural language, "
189
- "and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, "
190
- "and (4) show you the results."
191
- )
192
-
193
- state = gr.State({"db_path": None, "sql_db": None, "schema_text": "", "chain": None})
194
-
195
- with gr.Row():
196
- db_file = gr.File(label="Upload SQlite Database", file_types=[".sqlite", ".db"])
197
- upload_status = gr.Textbox(label="upload Status", interactive=False)
198
-
199
- schema_box = gr.Accordion("Database schema (preview)", open=False)
200
- with schema_box:
201
- schema_md = gr.Markdown("(no schema)")
202
-
203
- gr.Markdown("---")
204
-
205
- with gr.Row():
206
- question = gr.Textbox(label="Your question", placeholder="e.g., Top 10 tracks by total sales")
207
- with gr.Row():
208
- max_row= gr.Slider(10, 5000, value=1000, step=10, label="Max rows")
209
-
210
- with gr.Row():
211
- run_btn = gr.Button("Generate & Run SQL", variant="primary")
212
-
213
- with gr.Row():
214
- status_out = gr.Textbox(label="Status")
215
- with gr.Row():
216
- sql_out = gr.Code(label="Generated SQL (validated)")
217
- with gr.Row():
218
- result_out = gr.Code(label="Result (JSON sample)")
219
-
220
- db_file.change(
221
- fn=on_upload_database,
222
- inputs=[db_file, state],
223
- outputs=[state, upload_status, schema_md],
224
- )
225
-
226
- run_btn.click(
227
- fn=on_generate_query,
228
- inputs=[question, max_row, state],
229
- outputs=[status_out, sql_out, result_out],
230
- )
231
-
232
-
233
-
234
- if __name__ == "__main__":
235
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/__init__.py ADDED
File without changes
app/main.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ load_dotenv()
3
+
4
+ from fastapi import FastAPI
5
+ from app.routers import nl2sql
6
+ app = FastAPI(
7
+ title="NL2SQL Copilot Prototype",
8
+ version="0.1.0",
9
+ description="Natural Language -> SQL Copilot API"
10
+ )
11
+
12
+ app.include_router(nl2sql.router, prefix="/api/v1")
13
+
14
+ @app.get("/healthz")
15
+ def health_check():
16
+ return {"status": "ok"}
17
+
18
+ @app.get("/")
19
+ def root():
20
+ return {"status": "ok", "message": "NL2SQL Copilot API is running"}
21
+
22
+ @app.get("/health")
23
+ def health():
24
+ return {
25
+ "status": "ok",
26
+ "db": "connected",
27
+ "llm": "reachable",
28
+ "uptime_sec": 123.4
29
+ }
app/routers/__init__.py ADDED
File without changes
app/routers/nl2sql.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict, is_dataclass
2
+ from fastapi import APIRouter, HTTPException
3
+ from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
4
+ from nl2sql.pipeline import Pipeline
5
+ from nl2sql.ambiguity_detector import AmbiguityDetector
6
+ from nl2sql.safety import Safety
7
+ from nl2sql.planner import Planner
8
+ from nl2sql.generator import Generator
9
+ from adapters.llm.openai_provider import OpenAIProvider
10
+ from nl2sql.types import StageResult
11
+ from nl2sql.executor import Executor
12
+ from nl2sql.verifier import Verifier
13
+ from nl2sql.repair import Repair
14
+ from adapters.db.sqlite_adapter import SQLiteAdapter
15
+ from adapters.db.postgres_adapter import PostgresAdapter
16
+ import os
17
+
18
+
19
+ router = APIRouter(prefix="/nl2sql")
20
+
21
+
22
+
23
+ if os.getenv("DB_MODE", "sqlite") == "postgres":
24
+ _db = PostgresAdapter(os.environ["POSTGRES_DSN"])
25
+ else:
26
+ _db = SQLiteAdapter("data/chinook.db")
27
+
28
+ # --- Composition Root ---
29
+ _llm = OpenAIProvider()  # NOTE(review): raises KeyError at import time if OPENAI_API_KEY is unset — consider lazy initialization
30
+ # _db = SQLiteAdapter("data/chinook.db")
31
+ _executor = Executor(_db)
32
+ _verifier = Verifier()
33
+ _repair = Repair(_llm)
34
+
35
+
36
+ _pipeline = Pipeline(
37
+ detector=AmbiguityDetector(),
38
+ planner=Planner(_llm),
39
+ generator=Generator(_llm),
40
+ safety=Safety(),
41
+ executor=_executor,
42
+ verifier=_verifier,
43
+ repair=_repair
44
+ )
45
+
46
+
47
+ def _to_dict(obj):
48
+ """Helper: safely convert dataclass → dict."""
49
+ return asdict(obj) if is_dataclass(obj) else obj
50
+
51
+ def _round_trace(t: dict) -> dict:
52
+ if t.get("cost_usd") is not None:
53
+ t["cost_usd"] = round(t["cost_usd"], 6)
54
+ if t.get("duration_ms") is not None:
55
+ t["duration_ms"] = round(t["duration_ms"], 2)
56
+ return t
57
+
58
+ @router.post("", name="nl2sql_handler")
59
+ def nl2sql_handler(request: NL2SQLRequest):
60
+ result = _pipeline.run(user_query=request.query, schema_preview=request.schema_preview)
61
+
62
+ # --- Ensure result type ---
63
+ if not isinstance(result, StageResult):
64
+ raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
65
+
66
+ data = result.data or {}
67
+
68
+ # --- Handle ambiguity ---
69
+ if isinstance(data, dict) and data.get("ambiguous") and data.get("questions"):
70
+ return ClarifyResponse(ambiguous=True, questions=data["questions"])
71
+
72
+ # --- Handle error ---
73
+ if not result.ok:
74
+ detail = "; ".join(result.error) if result.error else "Unknown error"
75
+ raise HTTPException(status_code=400, detail=detail)
76
+
77
+ # --- Success case ---
78
+ return NL2SQLResponse(
79
+ ambiguous=False,
80
+ sql=data.get("sql"),
81
+ rationale=data.get("rationale"),
82
+ traces=[_to_dict(t) for t in data.get("traces", [])],
83
+ )
app/schemas.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Optional, Any, Dict
3
+
4
+ class NL2SQLRequest(BaseModel):
5
+ query: str
6
+ schema_preview: str
7
+ db_name: Optional[str] = "default"
8
+
9
+ class TraceModel(BaseModel):
10
+ stage: str
11
+ duration_ms: float
12
+ token_in: int | None = 0
13
+ token_out: int | None = 0
14
+ cost_usd: float | None = 0
15
+ notes: Dict[str, Any] | None = None
16
+
17
+ class NL2SQLResponse(BaseModel):
18
+ ambiguous: bool = False
19
+ sql: str
20
+ rationale: Optional[str] = None
21
+ traces: List[TraceModel] = []
22
+
23
+ class ClarifyResponse(BaseModel):
24
+ ambiguous: bool = True
25
+ questions: List[str]
26
+
27
+ class ErrorResponse(BaseModel):
28
+ error: str
29
+ details: List[str] | None = None
benchmarks/results/demo.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {"query": "show all users", "exec_acc": 0.0, "safe_fail": 0.0, "latency_ms": 0.610041999607347, "cost_usd": 0.0, "repair_attempts": 0, "provider": "dummy-llm"}
2
+ {"query": "top spenders", "exec_acc": 0.0, "safe_fail": 0.0, "latency_ms": 0.005625000085274223, "cost_usd": 0.0, "repair_attempts": 0, "provider": "dummy-llm"}
3
+ {"query": "sum of spend", "exec_acc": 0.0, "safe_fail": 0.0, "latency_ms": 0.20833300004596822, "cost_usd": 0.0, "repair_attempts": 0, "provider": "dummy-llm"}
benchmarks/run.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # benchmarks/run.py
2
+ from __future__ import annotations
3
+ import argparse
4
+ import os
5
+ import json
6
+ import time
7
+ from pathlib import Path
8
+
9
+ # ---- app imports
10
+ from nl2sql.pipeline import Pipeline
11
+ from nl2sql.ambiguity_detector import AmbiguityDetector
12
+ from nl2sql.planner import Planner
13
+ from nl2sql.generator import Generator
14
+ from nl2sql.safety import Safety
15
+ from nl2sql.executor import Executor
16
+ from nl2sql.verifier import Verifier
17
+ from nl2sql.repair import Repair
18
+
19
+ # ---- adapters
20
+ from adapters.db.sqlite_adapter import SQLiteAdapter
21
+ from adapters.llm.openai_provider import OpenAIProvider
22
+
23
+ # ---- fallbacks: Dummy LLM (so it runs without API keys)
24
+ class DummyLLM:
25
+ provider_id = "dummy-llm"
26
+
27
+ def plan(self, *, user_query: str, schema_preview: str):
28
+ text = f"- understand question: {user_query}\n- identify tables\n- join if needed\n- filter\n- order/limit"
29
+ return text, 0, 0, 0.0
30
+
31
+ def generate_sql(self, *, user_query: str, schema_preview: str, plan_text: str, clarify_answers=None):
32
+ # naive demo SQL (so pipeline flows end-to-end)
33
+ sql = "SELECT 1 AS one;"
34
+ rationale = "Demo SQL from DummyLLM"
35
+ return sql, rationale, 0, 0, 0.0
36
+
37
+ def repair(self, *, sql: str, error_msg: str, schema_preview: str):
38
+ return sql, 0, 0, 0.0
39
+
40
+
41
+ def ensure_demo_db(path: Path) -> None:
42
+ """Create a tiny SQLite db if missing, so executor has something to run."""
43
+ if path.exists():
44
+ return
45
+ import sqlite3
46
+ path.parent.mkdir(parents=True, exist_ok=True)
47
+ con = sqlite3.connect(path)
48
+ cur = con.cursor()
49
+ cur.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
50
+ cur.executemany("INSERT INTO users(id,name,spend) VALUES(?,?,?)",
51
+ [(1,"Alice",120.5),(2,"Bob",80.0),(3,"Carol",155.0)])
52
+ con.commit()
53
+ con.close()
54
+
55
+
56
+ def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
57
+ # DB adapter
58
+ db = SQLiteAdapter(str(db_path))
59
+ executor = Executor(db)
60
+ # LLM provider
61
+ if use_openai and os.getenv("OPENAI_API_KEY"):
62
+ llm = OpenAIProvider()
63
+ else:
64
+ llm = DummyLLM()
65
+ # stages
66
+ detector = AmbiguityDetector()
67
+ planner = Planner(llm)
68
+ generator = Generator(llm)
69
+ safety = Safety()
70
+ verifier = Verifier()
71
+ repair = Repair(llm)
72
+ # pipeline
73
+ return Pipeline(
74
+ detector=detector,
75
+ planner=planner,
76
+ generator=generator,
77
+ safety=safety,
78
+ executor=executor,
79
+ verifier=verifier,
80
+ repair=repair,
81
+ )
82
+
83
+
84
+ def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
85
+ results = []
86
+ for q in queries:
87
+ t0 = time.perf_counter()
88
+ r = pipeline.run(user_query=q, schema_preview=schema_preview)
89
+ latency_ms = (time.perf_counter()-t0)*1000
90
+ ok = (not r.get("ambiguous")) and ("error" not in r)  # NOTE(review): app/routers/nl2sql.py treats pipeline.run() as returning a StageResult dataclass, not a dict — confirm which API is current
91
+
92
+ traces = r.get("traces", [])
93
+ cost_sum = 0.0
94
+ for t in traces:
95
+ try:
96
+ cost_sum += float(t.get("cost_usd", 0.0))
97
+ except Exception:
98
+ pass
99
+
100
+ results.append({
101
+ "query": q,
102
+ "exec_acc": 1.0 if ok else 0.0,
103
+ "safe_fail": 0.0 if ok else 1.0 if "unsafe" in str(r).lower() else 0.0,
104
+ "latency_ms": latency_ms,
105
+ "cost_usd": cost_sum,
106
+ "repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
107
+ "provider": pipeline.generator.llm.provider_id if hasattr(pipeline.generator, "llm") else "unknown",
108
+ })
109
+
110
+ outfile.parent.mkdir(parents=True, exist_ok=True)
111
+ with open(outfile, "w") as f:
112
+ for row in results:
113
+ f.write(json.dumps(row) + "\n")
114
+ print(f"[OK] wrote {len(results)} rows → {outfile}")
115
+
116
+
117
+ def main():
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
120
+ parser.add_argument("--db", default="data/bench_demo.db")
121
+ parser.add_argument("--use-openai", action="store_true", help="Use OpenAI provider if API key present")
122
+ args = parser.parse_args()
123
+
124
+ ROOT = Path(__file__).resolve().parents[1] # project root
125
+ outfile = (ROOT / args.outfile).resolve()
126
+ db_path = (ROOT / args.db).resolve()
127
+
128
+ ensure_demo_db(db_path)
129
+ pipe = build_pipeline(db_path, use_openai=args.use_openai)
130
+
131
+ # a small demo set; replace with Spider when ready
132
+ queries = [
133
+ "show all users",
134
+ "top spenders",
135
+ "sum of spend",
136
+ ]
137
+ schema_preview = "CREATE TABLE users(id INT, name TEXT, spend REAL);"
138
+
139
+ run_benchmark(queries, schema_preview, pipe, outfile)
140
+
141
+
142
+ if __name__ == "__main__":
143
+ main()
docker-compose.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.9"
2
+
3
+ services:
4
+ postgres:
5
+ image: postgres:16
6
+ container_name: nl2sql_pg
7
+ environment:
8
+ POSTGRES_USER: postgres
9
+ POSTGRES_PASSWORD: postgres
10
+ POSTGRES_DB: demo
11
+ volumes:
12
+ - pgdata:/var/lib/postgresql/data
13
+ - ./infra/migrate.sql:/docker-entrypoint-initdb.d/00_init.sql:ro
14
+ ports:
15
+ - "5432:5432"
16
+ healthcheck:
17
+ test: ["CMD-SHELL", "pg_isready -U postgres -d demo"]
18
+ interval: 5s
19
+ timeout: 3s
20
+ retries: 10
21
+
22
+ api:
23
+ build:
24
+ context: .
25
+ dockerfile: Dockerfile
26
+ container_name: nl2sql_api
27
+ depends_on:
28
+ postgres:
29
+ condition: service_healthy
30
+ environment:
31
+ DB_MODE: postgres
32
+ POSTGRES_DSN: dbname=demo user=postgres password=postgres host=postgres port=5432
33
+ OPENAI_MODEL_ID: gpt-4o-mini
34
+ OPENAI_API_KEY: ${OPENAI_API_KEY}
35
+ ports:
36
+ - "8000:8000"
37
+ command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers"]
38
+
39
+ volumes:
40
+ pgdata:
infra/migrate.sql ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ CREATE TABLE IF NOT EXISTS users (
2
+ id SERIAL PRIMARY KEY,
3
+ name TEXT NOT NULL,
4
+ city TEXT
5
+ );
6
+
7
+ INSERT INTO users (name, city)
8
+ VALUES ('Alice', 'Tehran'), ('Bob', 'Karaj'), ('Carol', 'Isfahan');
logs/spider_eval/dev_gold_1760430884.txt DELETED
@@ -1,10 +0,0 @@
1
- SELECT count(*) FROM singer concert_singer
2
- SELECT count(*) FROM singer concert_singer
3
- SELECT name , country , age FROM singer ORDER BY age DESC concert_singer
4
- SELECT name , country , age FROM singer ORDER BY age DESC concert_singer
5
- SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France' concert_singer
6
- SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France' concert_singer
7
- SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1 concert_singer
8
- SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1 concert_singer
9
- SELECT DISTINCT country FROM singer WHERE age > 20 concert_singer
10
- SELECT DISTINCT country FROM singer WHERE age > 20 concert_singer
 
 
 
 
 
 
 
 
 
 
 
logs/spider_eval/dev_metrics_1760430884.json DELETED
@@ -1,15 +0,0 @@
1
- {
2
- "commit_hash": "e207f417ac5923220817e3c3f61c72e51a98c63b",
3
- "split": "dev",
4
- "limit": 10,
5
- "total_examples": 10,
6
- "valid_examples": 10,
7
- "exact_match_rate": 0.2,
8
- "exact_match_structural_rate": 0.0,
9
- "execution_accuracy_rate": 0.8,
10
- "error_rate": 0.0,
11
- "safe_check_fail_rate": 0.0,
12
- "avg_gen_time": 1.4374850749969483,
13
- "avg_exec_time": 0.0007865667343139648,
14
- "run_id": 1760430884
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
logs/spider_eval/dev_pred_1760430884.txt DELETED
@@ -1,10 +0,0 @@
1
- SELECT COUNT(*) AS total_singers FROM singer; concert_singer
2
- SELECT COUNT(*) AS total_singers FROM singer; concert_singer
3
- SELECT Name, Country, Age FROM singer ORDER BY Age DESC concert_singer
4
- SELECT Name, Country, Age FROM singer ORDER BY Age DESC concert_singer
5
- SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France' concert_singer
6
- SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France'; concert_singer
7
- SELECT Name, Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer) concert_singer
8
- SELECT Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer) concert_singer
9
- SELECT DISTINCT Country FROM singer WHERE Age > 20 concert_singer
10
- SELECT DISTINCT Country FROM singer WHERE Age > 20 concert_singer
 
 
 
 
 
 
 
 
 
 
 
logs/spider_eval/dev_results_1760430884.jsonl DELETED
@@ -1,11 +0,0 @@
1
- # {"commit_hash": "e207f417ac5923220817e3c3f61c72e51a98c63b", "split": "dev", "limit": 10, "start_time": 1760430884}
2
- {"db_id": "concert_singer", "question": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "SELECT COUNT(*) AS total_singers FROM singer;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"total_singers\": 6\n }\n]", "gen_time": 1.2182981967926025, "exec_time": 0.0008916854858398438, "error": null, "gold_error": null, "pred_rows": "[(6,)]", "gold_rows": "[(6,)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
3
- {"db_id": "concert_singer", "question": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "SELECT COUNT(*) AS total_singers FROM singer;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"total_singers\": 6\n }\n]", "gen_time": 1.261944055557251, "exec_time": 0.00044798851013183594, "error": null, "gold_error": null, "pred_rows": "[(6,)]", "gold_rows": "[(6,)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
4
- {"db_id": "concert_singer", "question": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "SELECT Name, Country, Age FROM singer ORDER BY Age DESC", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Country\": \"Netherlands\",\n \"Age\": 52\n },\n {\n \"Name\": \"John Nizinik\",\n \"Country\": \"France\",\n \"Age\": 43\n },\n {\n \"Name\": \"Rose White\",\n \"Country\": \"France\",\n \"Age\": 41\n },\n {\n \"Name\": \"Timbaland\",\n \"Country\": \"United States\",\n \"Age\": 32\n },\n {\n \"Name\": \"Justin Brown\",\n \"Country\": \"France\",\n \"Age\": 29\n },\n {\n \"Name\": \"Tribal King\",\n \"Country\": \"France\",\n \"Age\": 25\n }\n]", "gen_time": 1.0276496410369873, "exec_time": 0.0006437301635742188, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "gold_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
5
- {"db_id": "concert_singer", "question": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "SELECT Name, Country, Age FROM singer ORDER BY Age DESC", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Country\": \"Netherlands\",\n \"Age\": 52\n },\n {\n \"Name\": \"John Nizinik\",\n \"Country\": \"France\",\n \"Age\": 43\n },\n {\n \"Name\": \"Rose White\",\n \"Country\": \"France\",\n \"Age\": 41\n },\n {\n \"Name\": \"Timbaland\",\n \"Country\": \"United States\",\n \"Age\": 32\n },\n {\n \"Name\": \"Justin Brown\",\n \"Country\": \"France\",\n \"Age\": 29\n },\n {\n \"Name\": \"Tribal King\",\n \"Country\": \"France\",\n \"Age\": 25\n }\n]", "gen_time": 1.4854280948638916, "exec_time": 0.0007121562957763672, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "gold_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
6
- {"db_id": "concert_singer", "question": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France'", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_age\": 34.5,\n \"minimum_age\": 25,\n \"maximum_age\": 43\n }\n]", "gen_time": 1.8432340621948242, "exec_time": 0.0009491443634033203, "error": null, "gold_error": null, "pred_rows": "[(34.5, 25, 43)]", "gold_rows": "[(34.5, 25, 43)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
7
- {"db_id": "concert_singer", "question": "What is the average, minimum, and maximum age for all French singers?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France';", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_age\": 34.5,\n \"minimum_age\": 25,\n \"maximum_age\": 43\n }\n]", "gen_time": 2.5529589653015137, "exec_time": 0.0012962818145751953, "error": null, "gold_error": null, "pred_rows": "[(34.5, 25, 43)]", "gold_rows": "[(34.5, 25, 43)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
8
- {"db_id": "concert_singer", "question": "Show the name and the release year of the song by the youngest singer.", "gold_sql": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "pred_sql": "SELECT Name, Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer)", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Song_Name\": \"You\",\n \"Song_release_year\": \"1992\"\n }\n]", "gen_time": 1.2428169250488281, "exec_time": 0.0007181167602539062, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'You', '1992')]", "gold_rows": "[('Love', '2016')]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": false, "safe_check_failed": false}
9
- {"db_id": "concert_singer", "question": "What are the names and release years for all the songs of the youngest singer?", "gold_sql": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "pred_sql": "SELECT Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer)", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Song_Name\": \"You\",\n \"Song_release_year\": \"1992\"\n }\n]", "gen_time": 1.4568238258361816, "exec_time": 0.0009098052978515625, "error": null, "gold_error": null, "pred_rows": "[('You', '1992')]", "gold_rows": "[('Love', '2016')]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": false, "safe_check_failed": false}
10
- {"db_id": "concert_singer", "question": "What are all distinct countries where singers above age 20 are from?", "gold_sql": "SELECT DISTINCT country FROM singer WHERE age > 20", "pred_sql": "SELECT DISTINCT Country FROM singer WHERE Age > 20", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Country\": \"Netherlands\"\n },\n {\n \"Country\": \"United States\"\n },\n {\n \"Country\": \"France\"\n }\n]", "gen_time": 0.9801719188690186, "exec_time": 0.0007050037384033203, "error": null, "gold_error": null, "pred_rows": "[('Netherlands',), ('United States',), ('France',)]", "gold_rows": "[('Netherlands',), ('United States',), ('France',)]", "exact_match": true, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
11
- {"db_id": "concert_singer", "question": "What are the different countries with singers above age 20?", "gold_sql": "SELECT DISTINCT country FROM singer WHERE age > 20", "pred_sql": "SELECT DISTINCT Country FROM singer WHERE Age > 20", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Country\": \"Netherlands\"\n },\n {\n \"Country\": \"United States\"\n },\n {\n \"Country\": \"France\"\n }\n]", "gen_time": 1.3055250644683838, "exec_time": 0.0005917549133300781, "error": null, "gold_error": null, "pred_rows": "[('Netherlands',), ('United States',), ('France',)]", "gold_rows": "[('Netherlands',), ('United States',), ('France',)]", "exact_match": true, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
 
 
 
 
 
 
 
 
 
 
 
 
nl2sql/__init__.py ADDED
File without changes
nl2sql/ambiguity_detector.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+
4
class AmbiguityDetector:
    """Lightweight AmbiSQL-style ambiguity detection.

    Scans the natural-language query for terms whose meaning depends on
    context (e.g. "recent" -> which time window?, "top" -> by what metric?)
    and returns one human-readable clarification message per hit.
    """

    # Terms that usually need clarification before SQL can be generated.
    AMBIGUOUS_TERMS = ["recent", "top", "name", "rank", "latest"]

    def detect(self, query: str, schema_preview: str) -> list[str]:
        """Return a list of ambiguity messages; empty when the query is clear.

        ``schema_preview`` is accepted for interface parity with the other
        stages; this lightweight detector does not inspect it yet.
        """
        hits: list[str] = []
        q_lower = query.lower()
        for term in self.AMBIGUOUS_TERMS:
            # re.escape keeps the pattern safe should a term ever contain
            # regex metacharacters; \b enforces whole-word matching.
            if re.search(rf"\b{re.escape(term)}\b", q_lower):
                # Bug fix: the original message ended with a stray
                # apostrophe ("... in this query.'").
                hits.append(f"The term '{term}' is ambiguous in this query.")

        return hits
nl2sql/executor.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from nl2sql.types import StageResult, StageTrace
3
+ from adapters.db.base import DBAdapter
4
+
5
class Executor:
    """Runs a (pre-validated) SQL statement against the configured DB adapter."""

    name = "executor"

    def __init__(self, db: DBAdapter):
        self.db = db

    def run(self, sql: str) -> StageResult:
        """Execute *sql* and wrap the outcome in a timed StageResult."""
        started = time.perf_counter()
        try:
            rows, cols = self.db.execute(sql)
        except Exception as exc:
            # Adapter errors become a failed stage, never a crash.
            failure_trace = StageTrace(
                stage=self.name,
                duration_ms=(time.perf_counter() - started) * 1000,
                notes={"error": str(exc)},
            )
            return StageResult(ok=False, data=None, trace=failure_trace, error=[str(exc)])

        success_trace = StageTrace(
            stage=self.name,
            duration_ms=(time.perf_counter() - started) * 1000,
            notes={"row_count": len(rows), "col_count": len(cols)},
        )
        return StageResult(ok=True, data={"rows": rows, "columns": cols}, trace=success_trace)
nl2sql/generator.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import time
3
+ from typing import Optional, Dict, Any
4
+ from nl2sql.types import StageResult, StageTrace
5
+ from adapters.llm.base import LLMProvider
6
+
7
class Generator:
    """LLM-backed SQL generation stage.

    Wraps ``llm.generate_sql`` and enforces the provider contract: a 5-tuple
    ``(sql, rationale, token_in, token_out, cost_usd)``.  On any contract
    violation the stage returns ok=False *without* a trace — only successful
    generations are timed and accounted.
    """

    name = "generator"

    def __init__(self, llm: LLMProvider) -> None:
        self.llm = llm

    def run(self, *, user_query: str, schema_preview: str, plan_text: str,
            clarify_answers: Optional[Dict[str, Any]] = None) -> StageResult:
        """Generate SQL for *user_query*; ok=False on provider/contract failures."""
        t0 = time.perf_counter()
        try:
            res = self.llm.generate_sql(
                user_query=user_query,
                schema_preview=schema_preview,
                plan_text=plan_text,
                clarify_answers=clarify_answers or {}
            )
        except Exception as e:
            return StageResult(ok=False, error=[f"Generator failed: {e}"])

        # Expect a 5-tuple
        if not isinstance(res, tuple) or len(res) != 5:
            return StageResult(ok=False, error=["Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"])

        sql, rationale, t_in, t_out, cost = res

        # Type/shape checks
        if not isinstance(sql, str) or not sql.strip():
            return StageResult(ok=False, error=["Generator produced empty or non-string SQL"])
        # Accept plain SELECTs and CTE queries ("WITH ... SELECT ...") so this
        # stage agrees with Safety, whose _ALLOW_SELECT regex also allows CTEs.
        # (A "WITH ... DELETE" would still be rejected by the Safety stage.)
        if not sql.lstrip().lower().startswith(("select", "with")):
            return StageResult(ok=False, error=[f"Generated non-SELECT SQL: {sql}"])

        rationale = rationale or ""  # guard len() below against None
        trace = StageTrace(
            stage=self.name,
            duration_ms=(time.perf_counter() - t0) * 1000.0,
            token_in=t_in,
            token_out=t_out,
            cost_usd=cost,
            notes={"rationale_len": len(rationale)},
        )

        return StageResult(ok=True, data={"sql": sql, "rationale": rationale}, trace=trace)
49
+
nl2sql/pipeline.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import traceback
3
+ from typing import Dict, Any, Optional, List
4
+ from nl2sql.types import StageResult
5
+ from nl2sql.ambiguity_detector import AmbiguityDetector
6
+ from nl2sql.planner import Planner
7
+ from nl2sql.generator import Generator
8
+ from nl2sql.safety import Safety
9
+ from nl2sql.executor import Executor
10
+ from nl2sql.verifier import Verifier
11
+ from nl2sql.repair import Repair
12
+
13
+
14
class Pipeline:
    """
    NL2SQL Copilot pipeline with guaranteed dict output.

    Stage order: ambiguity detection -> planner -> generator -> safety ->
    executor -> verifier, with a bounded repair loop when verification fails.
    All stages return structured traces and errors, but the final result is
    always a JSON-safe dict (see ``run`` for the exact shape).
    """

    def __init__(self, *,
                 detector: AmbiguityDetector,
                 planner: Planner,
                 generator: Generator,
                 safety: Safety,
                 executor: Executor,
                 verifier: Verifier,
                 repair: Repair):
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor
        self.verifier = verifier
        self.repair = repair

    # ------------------------------------------------------------
    def _trace_list(self, *stages: StageResult) -> List[dict]:
        """Collect each stage's trace as a plain dict, skipping absent ones."""
        traces = []
        for s in stages:
            if not s:
                continue
            t = getattr(s, "trace", None)
            if t:
                traces.append(t.__dict__)
        return traces

    # ------------------------------------------------------------
    def _safe_stage(self, fn, **kwargs) -> StageResult:
        """Run a stage safely; if it throws, catch and convert to StageResult."""
        try:
            r = fn(**kwargs)
            if isinstance(r, StageResult):
                return r
            # not ideal, but wrap it
            return StageResult(ok=True, data=r, trace=None)
        except Exception as e:
            tb = traceback.format_exc()
            # Bug fix: the StageResult field is ``error`` (nl2sql/types.py);
            # the old ``errors=`` keyword raised TypeError and hid the failure.
            return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])

    # ------------------------------------------------------------
    def run(self, *, user_query: str, schema_preview: str,
            clarify_answers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Always returns:
        {
          "ambiguous": bool,
          "error": bool,
          "details": list[str] | None,
          "sql": str | None,
          "rationale": str | None,
          "verified": bool | None,
          "traces": list[dict]
        }
        """
        traces: List[dict] = []
        details: List[str] = []
        sql, rationale, verified = None, None, None

        # --- 1) ambiguity detection
        try:
            questions = self.detector.detect(user_query, schema_preview)
            if questions:
                return {
                    "ambiguous": True,
                    "error": False,
                    "details": [f"Ambiguities found: {len(questions)}"],
                    "questions": questions,
                    "traces": []
                }
        except Exception as e:
            return {"ambiguous": True, "error": True, "details": [f"Detector failed: {e}"], "traces": []}

        # --- 2) planner
        r_plan = self._safe_stage(self.planner.run, user_query=user_query, schema_preview=schema_preview)
        traces.extend(self._trace_list(r_plan))
        if not r_plan.ok:
            # Bug fix: read ``.error`` — StageResult has no ``errors`` attribute,
            # so the old code raised AttributeError on every failure path.
            return {"ambiguous": False, "error": True, "details": r_plan.error, "traces": traces}

        # --- 3) generator
        r_gen = self._safe_stage(self.generator.run,
                                 user_query=user_query,
                                 schema_preview=schema_preview,
                                 plan_text=r_plan.data.get("plan"),
                                 clarify_answers=clarify_answers or {})
        traces.extend(self._trace_list(r_gen))
        if not r_gen.ok:
            return {"ambiguous": False, "error": True, "details": r_gen.error, "traces": traces}
        sql = r_gen.data.get("sql")
        rationale = r_gen.data.get("rationale")

        # --- 4) safety
        r_safe = self._safe_stage(self.safety.check, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
            return {"ambiguous": False, "error": True, "details": r_safe.error, "traces": traces}

        # --- 5) executor (failure is recoverable: repair loop may fix the SQL)
        r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
        traces.extend(self._trace_list(r_exec))
        if not r_exec.ok:
            details.extend(r_exec.error or [])

        # --- 6) verifier
        r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
        traces.extend(self._trace_list(r_ver))
        verified = bool(r_ver.ok)

        # --- 7) repair loop if verification failed (at most 2 attempts)
        if not verified:
            for attempt in range(2):
                r_fix = self._safe_stage(self.repair.run,
                                         sql=sql,
                                         error_msg="; ".join(details or ["unknown"]),
                                         schema_preview=schema_preview)
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    break
                sql = r_fix.data.get("sql")
                # Re-run safety -> executor -> verifier on the repaired SQL.
                r_safe = self._safe_stage(self.safety.check, sql=sql)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
                    details.extend(r_safe.error or [])
                    continue
                r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
                traces.extend(self._trace_list(r_exec))
                if not r_exec.ok:
                    details.extend(r_exec.error or [])
                    continue
                r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
                traces.extend(self._trace_list(r_ver))
                verified = bool(r_ver.ok)
                if verified:
                    break

        # --- Final result dict
        return {
            "ambiguous": False,
            "error": len(details) > 0 and not verified,
            "details": details or None,
            "sql": sql,
            "rationale": rationale,
            "verified": verified,
            "traces": traces,
        }
nl2sql/planner.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import time
3
+ from nl2sql.types import StageResult, StageTrace
4
+ from adapters.llm.base import LLMProvider
5
+
6
class Planner:
    """LLM stage that turns a user question + schema preview into a textual query plan."""

    name = "planner"

    def __init__(self, llm: LLMProvider) -> None:
        self.llm = llm

    def run(self, *, user_query: str, schema_preview: str) -> StageResult:
        """Ask the LLM for a plan.

        Always returns ok=True; exceptions are caught by the pipeline's
        ``_safe_stage`` wrapper rather than here.
        """
        t0 = time.perf_counter()
        # Provider contract: (plan_text, tokens_in, tokens_out, cost_usd).
        plan_text, t_in, t_out, cost = self.llm.plan(user_query=user_query, schema_preview=schema_preview)
        trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
                           token_in=t_in, token_out=t_out, cost_usd=cost, notes={"len_plan": len(plan_text)})
        return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
nl2sql/repair.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+
4
+ from nl2sql.types import StageTrace, StageResult
5
+ from adapters.llm.base import LLMProvider
6
+
7
+ GUIDELINES = """
8
+ When repairing:
9
+ 1. Keep query SELECT-only.
10
+ 2. Explicitly qualify ambiguous columns with table names.
11
+ 3. Match GROUP BY fields with aggregations.
12
+ 4. Use known foreign keys for JOIN.
13
+ 5. Add a reasonable LIMIT if missing.
14
+ Return only the corrected SQL.
15
+ """
16
+
17
class Repair:
    """LLM stage that attempts to fix a failing SQL statement."""

    name = "repair"

    def __init__(self, llm: LLMProvider):
        self.llm = llm

    def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
        """Ask the LLM to repair *sql*, prefixing the error with the GUIDELINES prompt."""
        t0 = time.perf_counter()
        # Provider contract: (fixed_sql, tokens_in, tokens_out, cost_usd).
        fixed_sql, t_in, t_out, cost = self.llm.repair(sql=sql, error_msg=f"{GUIDELINES}\n\n{error_msg}",
                                                       schema_preview=schema_preview)
        trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
                           token_in=t_in, token_out=t_out, cost_usd=cost,
                           notes={"old_sql_len": len(sql), "new_sql_len": len(fixed_sql)})
        return StageResult(ok=True, data={"sql": fixed_sql}, trace=trace)
nl2sql/safety.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import re, time
3
+ from nl2sql.types import StageResult, StageTrace
4
+
5
+ # --- Regex utils ---
6
+ _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
7
+ _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
8
+ # string literals (single & double quotes), allow escaped quotes
9
+ _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
10
+ _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
11
+
12
+ # case-insensitive, word-boundary forbidden keywords
13
+ _FORBIDDEN = re.compile(
14
+ r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
15
+ re.IGNORECASE,
16
+ )
17
+
18
+ # allow: SELECT ... or WITH <cte...> SELECT ...
19
+ _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
20
+
21
def _strip_comments(s: str) -> str:
    """Replace /* ... */ block and -- line comments with a space each."""
    s = _COMMENT_BLOCK.sub(" ", s)
    s = _COMMENT_LINE.sub(" ", s)
    return s
25
+
26
def _mask_strings(s: str) -> str:
    """Collapse quoted literals to 'X' / "X" so keywords inside strings cannot trip the checks."""
    s = _STRING_SINGLE.sub("'X'", s)
    s = _STRING_DOUBLE.sub('"X"', s)
    return s
30
+
31
+ def _split_statements(s: str) -> list[str]:
32
+ parts = [p.strip() for p in s.split(";")]
33
+ return [p for p in parts if p]
34
+
35
class Safety:
    """SELECT-only gate applied to every generated SQL statement.

    Comments and string literals are neutralised first so that forbidden
    keywords hiding inside them can neither trigger nor dodge the checks.
    """

    name = "safety"

    def check(self, sql: str) -> StageResult:
        """Validate *sql*; on ok=True, ``data['sql']`` carries the stripped statement forward."""
        t0 = time.perf_counter()
        # Bug fix: removed the leftover debug print() of every SQL candidate —
        # a library stage must not write to stdout.
        s = _strip_comments(sql)
        s = _mask_strings(s).strip()

        # Exactly one statement allowed — blocks "SELECT 1; DROP TABLE x".
        stmts = _split_statements(s)
        if len(stmts) != 1:
            return StageResult(
                ok=False,
                error=["Multiple statements detected"],
                trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
            )

        body = stmts[0]

        # Any DML/DDL keyword anywhere in the masked statement is rejected.
        if _FORBIDDEN.search(body):
            return StageResult(
                ok=False,
                error=["Forbidden keyword detected"],
                trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
            )

        # Must start with SELECT, optionally preceded by a WITH ... CTE block.
        if not _ALLOW_SELECT.match(body):
            return StageResult(
                ok=False,
                error=["Non-SELECT statement"],
                trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
            )

        return StageResult(
            ok=True,
            data={
                "sql": sql.strip(),
                "rationale": "Statement validated as SELECT-only (strings/comments ignored).",
            },
            trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000),
        )
nl2sql/stubs.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.types import StageResult, StageTrace
2
+
3
class NoOpExecutor:
    """Test stand-in: never touches a database, always 'succeeds' with no rows."""

    name = "executor"

    def run(self, sql: str) -> StageResult:
        # pretend success, return empty result set
        return StageResult(
            ok=True,
            data={"rows": [], "columns": []},
            trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
        )
12
+
13
class NoOpVerifier:
    """Test stand-in: reports every statement as verified."""

    name = "verifier"

    def run(self, sql: str, exec_result: StageResult) -> StageResult:
        # always verified for legacy tests
        return StageResult(
            ok=True,
            data={"verified": True},
            trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
        )
22
+
23
class NoOpRepair:
    """Test stand-in: 'repairs' by returning the original SQL unchanged."""

    name = "repair"

    def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
        # return original SQL unchanged
        return StageResult(
            ok=True,
            data={"sql": sql},
            trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
        )
nl2sql/types.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional, List
3
+
4
@dataclass(frozen=True)
class StageTrace:
    """Immutable telemetry record emitted by a pipeline stage."""

    stage: str                              # stage name, e.g. "planner"
    duration_ms: float                      # wall-clock time spent in the stage
    notes: Optional[Dict[str, Any]] = None  # free-form stage-specific metadata
    token_in: Optional[int] = None          # LLM prompt tokens (LLM stages only)
    token_out: Optional[int] = None         # LLM completion tokens (LLM stages only)
    cost_usd: Optional[float] = None        # estimated LLM cost (LLM stages only)
12
+
13
@dataclass(frozen=True)
class StageResult:
    """Outcome of a single pipeline stage.

    ``error`` holds human-readable failure messages (None on success).
    Several call sites historically read ``.errors`` (plural); the read-only
    ``errors`` property keeps them working without duplicating state.
    """

    ok: bool                                 # did the stage succeed?
    data: Optional[Any] = None               # stage payload (shape varies per stage)
    trace: Optional[StageTrace] = None       # timing/token telemetry, if recorded
    error: Optional[List[str]] = None        # failure messages, None when ok
    notes: Optional[Dict[str, Any]] = None   # free-form extras

    @property
    def errors(self) -> Optional[List[str]]:
        """Backward-compatible alias for ``error``."""
        return self.error
nl2sql/verifier.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlglot
2
+ from sqlglot import expressions as exp
3
+ from nl2sql.types import StageResult, StageTrace
4
+
5
class Verifier:
    """Post-execution sanity checks on the generated SQL."""

    name = "verifier"

    def run(self, sql: str, exec_result: StageResult) -> StageResult:
        """Fail if execution failed, or the parsed SQL violates simple structural rules."""
        if not exec_result.ok:
            # Bug fix: propagate ``exec_result.error`` — StageResult has no
            # ``errors`` attribute, so the old code raised AttributeError here.
            return StageResult(ok=False, data=None,
                               trace=StageTrace(stage=self.name, duration_ms=0,
                                                notes={"reason": "execution_error"}),
                               error=exec_result.error)

        # Rule 1: check SELECT / GROUP consistency.
        # NOTE(review): bare aggregates (e.g. "SELECT COUNT(*) FROM t") are
        # legal SQL; this intentionally strict rule routes them through the
        # repair loop — confirm that is the desired behavior.
        issues = []
        try:
            tree = sqlglot.parse_one(sql)
            if isinstance(tree, exp.Select):
                group = tree.args.get("group")
                aggs = list(tree.find_all(exp.AggFunc))
                if aggs and not group:
                    issues.append("Aggregation without GROUP BY.")
        except Exception as e:
            issues.append(f"Parse error during verification: {e}")

        if issues:
            return StageResult(ok=False, data=None,
                               trace=StageTrace(stage=self.name, duration_ms=0,
                                                notes={"issues": issues}),
                               error=issues)
        return StageResult(ok=True, data={"verified": True},
                           trace=StageTrace(stage=self.name, duration_ms=0))
requirements.txt CHANGED
@@ -1,8 +1,11 @@
1
- gradio
2
- langchain
3
- langchain-openai
4
- langchain_community
5
- sqlglot
6
- openai
7
- python-dotenv
8
- dotenv
 
 
 
 
1
fastapi==0.115.2
uvicorn[standard]==0.30.6
pydantic==2.9.2
sqlglot==27.26.0
requests==2.32.3
streamlit==1.39.0
plotly==5.24.1
pytest==8.3.3
python-dotenv==1.1.1
openai==2.6.1
psycopg[binary]~=3.2
httpx==0.27.2   # required by fastapi.testclient.TestClient
ruff==0.6.9     # CI runs "ruff check ."
mypy==1.11.2    # CI runs "mypy ."
tests/conftest.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
import os
from dotenv import load_dotenv

# Load the repository-root .env before any tests run, so tests see the same
# configuration (API keys, DB settings) as the application itself.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ENV_PATH = os.path.join(ROOT_DIR, ".env")

load_dotenv(dotenv_path=ENV_PATH)
tests/test_ambiguity.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.ambiguity_detector import AmbiguityDetector
2
+ from nl2sql.types import StageResult
3
+ from app.routers import nl2sql
4
+
5
def test_detects_ambiguous_terms():
    """Vague words like 'recent'/'top' must each yield a clarification message."""
    det = AmbiguityDetector()
    res = det.detect("Show me recent top singers", "table: singer(id,name,age)")
    assert len(res) >= 1
    # "recent" comes first in AMBIGUOUS_TERMS, so it is the first hit.
    assert "recent" in res[0].lower()
10
+
11
def test_not_false_positive():
    """A fully specified query must produce no ambiguity hits."""
    det = AmbiguityDetector()
    res = det.detect("List all singers older than 30", "table: singer(id, name, age)")
    assert res == []
15
+
16
def test_ambiguity_response():
    """The router's _to_dict helper must preserve the 'ambiguous' flag."""
    fake_result = StageResult(ok=True, data={"ambiguous": True, "questions": ["Clarify column?"]})
    response = nl2sql._to_dict(fake_result.data)
    assert response["ambiguous"] is True
tests/test_executor.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.executor import Executor
2
+ from adapters.db.sqlite_adapter import SQLiteAdapter
3
+
4
def test_executor_runs_select(tmp_path):
    """Executor should run a SELECT against a real SQLite file and return rows."""
    db_path = tmp_path / "test.db"
    import sqlite3
    conn = sqlite3.connect(db_path)
    conn.execute("CREATE TABLE users(id INT, name TEXT);")
    conn.execute("INSERT INTO users VALUES (1, 'Alice');")
    conn.commit()
    conn.close()

    ex = Executor(SQLiteAdapter(str(db_path)))
    res = ex.run("SELECT * FROM users;")
    assert res.ok
    # rows come back as tuples: (id, name)
    assert res.data["rows"][0][1] == "Alice"
tests/test_generator.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from nl2sql.generator import Generator
3
+ from nl2sql.types import StageResult
4
+
5
+
6
+ # --- Dummy LLMs (respect the 5-tuple contract) --------------------------------
7
+
8
class LLM_OK:
    """Happy-path dummy honouring the (sql, rationale, t_in, t_out, cost) contract."""

    def generate_sql(self, **kwargs):
        return ("SELECT * FROM singer;", "list all", 10, 5, 0.00001)
12
+
13
+
14
class LLM_EMPTY_SQL:
    """Dummy returning an empty SQL string — Generator must flag this as an error."""

    def generate_sql(self, **kwargs):
        return ("", "reason", 10, 5, 0.0)
18
+
19
+
20
class LLM_NON_SELECT:
    """Dummy returning non-SELECT SQL — Generator must flag this as an error."""

    def generate_sql(self, **kwargs):
        return ("UPDATE users SET name='x' WHERE id=1;", "bad", 8, 3, 0.0)
24
+
25
+
26
class LLM_CONTRACT_NONE:
    """Contract-violating dummy: yields None instead of the required 5-tuple."""

    def generate_sql(self, **kwargs):
        return None
30
+
31
+
32
class LLM_CONTRACT_SHORT:
    """Contract-violating dummy: tuple with only 2 of the required 5 items."""

    def generate_sql(self, **kwargs):
        return ("SELECT * FROM singer;", "list all")
36
+
37
+
38
+ # --- Parametrized negative cases ----------------------------------------------
39
+
40
@pytest.mark.parametrize(
    "llm, err_keyword",
    [
        (LLM_EMPTY_SQL(), "empty"),        # empty or non-string sql
        (LLM_NON_SELECT(), "non-select"),  # generated non-SELECT
        (LLM_CONTRACT_NONE(), "contract violation"),
        (LLM_CONTRACT_SHORT(), "contract violation"),
    ],
)
def test_generator_errors_do_not_create_trace(llm, err_keyword):
    """Every contract violation must yield ok=False, a matching message, and no trace."""
    gen = Generator(llm=llm)
    r = gen.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={}
    )
    assert isinstance(r, StageResult)
    assert r.ok is False
    # Error message is flexible; just check a keyword
    joined = " ".join(r.error or []).lower()
    assert err_keyword in joined
    # On errors, Generator should not attach a trace (we measure only successful stage)
    assert r.trace is None
64
+
65
+
66
+ # --- Positive case (success) ---------------------------------------------------
67
+
68
def test_generator_success_has_valid_trace_and_data():
    """A well-behaved LLM must yield ok=True with SQL, rationale, and a coherent trace."""
    gen = Generator(llm=LLM_OK())
    r = gen.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={}
    )

    # Basic success checks
    assert isinstance(r, StageResult)
    assert r.ok is True
    assert r.data and r.data["sql"].lower().startswith("select")
    assert "rationale" in r.data

    # Trace should exist and be coherent
    assert r.trace is not None
    assert r.trace.stage == "generator"
    assert isinstance(r.trace.duration_ms, float)
    assert r.trace.token_in == 10
    assert r.trace.token_out == 5
    # cost can be float or None depending on provider; if present must be numeric
    if r.trace.cost_usd is not None:
        assert isinstance(r.trace.cost_usd, float)

    # Optional notes check – rationale_len should match length of rationale
    notes = r.trace.notes or {}
    if "rationale_len" in notes:
        assert notes["rationale_len"] == len(r.data.get("rationale", ""))
tests/test_nl2sql_router.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+ from app.main import app
4
+ from nl2sql.types import StageResult, StageTrace
5
+
6
+ client = TestClient(app)
7
+
8
+
9
def fake_trace(stage: str):
    """Build a minimal StageTrace for stubbing pipeline responses."""
    return StageTrace(stage=stage, duration_ms=10.0)


# Resolve the route once at import time; fails fast if the handler is renamed.
path = app.url_path_for("nl2sql_handler")
13
+
14
+ # --- 1) Clarify / ambiguity case ---------------------------------------------
15
def test_ambiguity_route(monkeypatch):
    """When the pipeline reports ambiguity, the API echoes clarification questions."""
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        # Pipeline stub: ambiguous result carrying one clarification question.
        return StageResult(
            ok=True,
            data={
                "ambiguous": True,
                "questions": ["Which table do you mean?"],
                "traces": [fake_trace("detector")],
            },
        )

    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    response = client.post(
        path,
        json={"query": "show all records", "schema_preview": "CREATE TABLE ..."},
    )

    assert response.status_code == 200
    body = response.json()
    assert body["ambiguous"] is True
    assert "questions" in body
43
+
44
+
45
+ # --- 2) Error / failure case -------------------------------------------------
46
def test_error_route(monkeypatch):
    """A failed pipeline run surfaces as HTTP 400 with the error in `detail`."""
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        return StageResult(
            ok=False,
            error=["Bad SQL"],
            data={"traces": [fake_trace("safety")]},
        )

    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    response = client.post(
        path,
        json={
            "query": "drop table users;",
            "schema_preview": "CREATE TABLE users(id int);",
        },
    )

    assert response.status_code == 400
    assert "Bad SQL" in response.json()["detail"]
64
+
65
+
66
+ # --- 3) Success / happy path -------------------------------------------------
67
def test_success_route(monkeypatch):
    """Happy path: generated SQL plus per-stage traces come back as JSON."""
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        return StageResult(
            ok=True,
            data={
                "ambiguous": False,
                "sql": "SELECT * FROM users;",
                "rationale": "Simple listing",
                "traces": [fake_trace("planner"), fake_trace("generator")],
            },
        )

    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    response = client.post(
        path,
        json={
            "query": "show all users",
            "schema_preview": "CREATE TABLE users(id int, name text);",
        },
    )

    assert response.status_code == 200
    body = response.json()
    assert body["sql"].lower().startswith("select")
    assert isinstance(body["traces"], list)
    assert any(t["stage"] == "planner" for t in body["traces"])
tests/test_openai_provider.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pytest
3
+ from adapters.llm.openai_provider import OpenAIProvider
4
+
5
+
6
+ # Helper class to fake the completion object returned by OpenAI SDK
7
class FakeCompletion:
    """Duck-typed stand-in for the OpenAI chat-completion response object."""

    def __init__(self, content: str, prompt_tokens=5, completion_tokens=7):
        # Anonymous attribute holders mimic completion.choices[0].message.content.
        message = type("Msg", (), {"content": content})
        self.choices = [type("Choice", (), {"message": message})]
        # Mirrors the SDK's `usage` token-accounting fields.
        self.usage = type(
            "Usage",
            (),
            {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens},
        )
14
+
15
+
16
+ # --- Case 1: clean valid JSON --------------------------------------------------
17
def test_generate_sql_valid_json(monkeypatch):
    """Provider parses a clean JSON payload into (sql, rationale, tokens, cost)."""
    provider = OpenAIProvider()

    reply = json.dumps({
        "sql": "SELECT * FROM singer;",
        "rationale": "List all singers.",
    })
    completion = FakeCompletion(reply)

    # Replace the network call with a canned completion.
    monkeypatch.setattr(
        provider.client.chat.completions, "create", lambda *a, **kw: completion
    )

    sql, rationale, t_in, t_out, cost = provider.generate_sql(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={},
    )

    assert sql.strip().lower().startswith("select")
    assert "singer" in sql.lower()
    assert "list" in rationale.lower()
    assert t_in == 5 and t_out == 7
    assert isinstance(cost, float)
44
+
45
+
46
+ # --- Case 2: malformed JSON with extra text (should still recover) ------------
47
def test_generate_sql_recover_from_partial_json(monkeypatch):
    """Provider extracts the JSON object even when wrapped in extra chatter."""
    provider = OpenAIProvider()

    # Valid JSON object surrounded by prose the model tends to add.
    noisy_reply = (
        "Here is the result:\n"
        '{ "sql": "SELECT * FROM users;", "rationale": "list users" }\n'
        "Thanks!"
    )
    completion = FakeCompletion(noisy_reply)

    monkeypatch.setattr(
        provider.client.chat.completions, "create", lambda *a, **kw: completion
    )

    sql, rationale, *_ = provider.generate_sql(
        user_query="show all users",
        schema_preview="CREATE TABLE users(id int, name text);",
        plan_text="-- plan --",
    )

    assert sql.lower().startswith("select")
    assert "user" in sql.lower()
    assert "list" in rationale.lower()
68
+
69
+
70
+ # --- Case 3: completely invalid JSON (should raise ValueError) ----------------
71
def test_generate_sql_invalid_json(monkeypatch):
    """Unparseable model output must raise ValueError."""
    provider = OpenAIProvider()

    completion = FakeCompletion("This is nonsense output without braces")

    monkeypatch.setattr(
        provider.client.chat.completions, "create", lambda *a, **kw: completion
    )

    with pytest.raises(ValueError):
        provider.generate_sql(
            user_query="show X",
            schema_preview="CREATE TABLE t(id int);",
            plan_text="-- plan --",
        )
tests/test_pipeline_integration.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from nl2sql.pipeline import Pipeline
3
+ from nl2sql.types import StageResult, StageTrace
4
+
5
+
6
+ # --- Dummy stages to isolate pipeline -----------------------------------------
7
+
8
class DummyDetector:
    """Simulates the ambiguity-detector stage."""

    def __init__(self, ambiguous=False):
        # When True, detect() pretends the query needs clarification.
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        """Return clarification questions; an empty list means unambiguous."""
        if self.ambiguous:
            return ["Which column?"]
        return []
16
+
17
+
18
class DummyPlanner:
    """Simulates the planner stage."""

    def run(self, *, user_query, schema_preview):
        trace = StageTrace(stage="planner", duration_ms=1.0)
        # Sentinel substring in the query forces a planner failure.
        if "fail_plan" in user_query:
            return StageResult(ok=False, error=["Planner failed"], trace=trace)
        return StageResult(ok=True, data={"plan": "plan text"}, trace=trace)


class DummyGenerator:
    """Simulates the generator stage."""

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        trace = StageTrace(stage="generator", duration_ms=1.0)
        # Sentinel substring in the query forces a generator failure.
        if "fail_gen" in user_query:
            return StageResult(ok=False, error=["Generator failed"], trace=trace)
        return StageResult(
            ok=True,
            data={"sql": "SELECT * FROM singer;", "rationale": "List all singers."},
            trace=trace,
        )


class DummySafety:
    """Simulates the safety stage."""

    def check(self, sql):
        trace = StageTrace(stage="safety", duration_ms=1.0)
        # Anything containing DROP (case-insensitive) is treated as unsafe.
        if "DROP" in sql.upper():
            return StageResult(ok=False, error=["Unsafe SQL"], trace=trace)
        return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace)
45
+
46
+
47
+ # --- 1) Success path ----------------------------------------------------------
48
def test_pipeline_success():
    """All stages succeed: result carries SQL plus a trace from every stage."""
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=False),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )

    result = pipeline.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
    )

    assert isinstance(result, StageResult)
    assert result.ok is True
    payload = result.data or {}
    assert payload["sql"].lower().startswith("select")
    stages_seen = {t.stage for t in payload["traces"]}
    assert {"planner", "generator", "safety"} <= stages_seen
68
+
69
+
70
+ # --- 2) Ambiguity case --------------------------------------------------------
71
def test_pipeline_ambiguity():
    """An ambiguous query short-circuits with clarification questions."""
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=True),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )

    result = pipeline.run(
        user_query="show data",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(result, StageResult)
    assert result.ok is True
    assert result.data["ambiguous"] is True
    assert isinstance(result.data["questions"], list)
88
+
89
+
90
+ # --- 3) Planner failure -------------------------------------------------------
91
def test_pipeline_plan_fail():
    """Planner failure propagates: pipeline returns ok=False with its error."""
    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    result = pipeline.run(
        user_query="fail_plan",
        schema_preview="CREATE TABLE singer(id int);",
    )
    assert isinstance(result, StageResult)
    assert result.ok is False
    assert "Planner failed" in " ".join(result.error or [])


# --- 4) Generator failure -----------------------------------------------------
def test_pipeline_gen_fail():
    """Generator failure propagates through the pipeline."""
    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    result = pipeline.run(
        user_query="fail_gen",
        schema_preview="CREATE TABLE singer(id int);",
    )
    assert result.ok is False
    assert "Generator failed" in " ".join(result.error or [])


# --- 5) Safety failure --------------------------------------------------------
def test_pipeline_safety_fail():
    """Destructive SQL produced by the generator is vetoed by the safety stage."""

    class UnsafeGen(DummyGenerator):
        # Always emits a DROP so the safety stage must reject it.
        def run(self, **kw):
            trace = StageTrace(stage="generator", duration_ms=1.0)
            return StageResult(
                ok=True,
                data={"sql": "DROP TABLE x;", "rationale": "oops"},
                trace=trace,
            )

    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=UnsafeGen(),
        safety=DummySafety(),
    )
    result = pipeline.run(
        user_query="drop something",
        schema_preview="CREATE TABLE x(id int);",
    )
    assert result.ok is False
    assert "unsafe" in " ".join(result.error or []).lower()
tests/test_safety.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.safety import Safety
2
+ import pytest
3
+
4
+
5
+
6
def test_safety_allows_select():
    """A plain SELECT passes and the trace is attributed to the safety stage."""
    checker = Safety()
    result = checker.check("SELECT * FROM users;")
    assert result.ok
    assert "sql" in result.data
    assert result.trace.stage == "safety"


def test_safety_allows_with_select_cte():
    """WITH ... SELECT (read-only CTE) statements must pass."""
    checker = Safety()
    sql = """
    WITH recent AS (
        SELECT id FROM users WHERE created_at > '2024-01-01'
    )
    SELECT * FROM users u JOIN recent r ON u.id = r.id;
    """
    assert checker.check(sql).ok


def test_safety_allows_select_with_comments_and_newlines():
    """Leading/trailing comments around a SELECT do not trip the checker."""
    checker = Safety()
    sql = "/* head */ \n -- inline\n SELECT 1; -- tail"
    assert checker.check(sql).ok


def test_safety_allows_keywords_inside_string_literals():
    """Forbidden keywords inside string literals are data, not statements."""
    checker = Safety()
    result = checker.check("SELECT 'DROP TABLE x' as note, 'delete from y' as text;")
    assert result.ok, result.error
35
+
36
+
37
def test_safety_blocks_delete():
    """A bare DELETE must be rejected with a recognizable error message."""
    checker = Safety()
    result = checker.check("DELETE FROM users;")
    assert not result.ok
    assert any("Forbidden" in e or "Non-SELECT" in e for e in (result.error or []))


@pytest.mark.parametrize("sql", [
    "UPDATE users SET name='X' WHERE id=1;",
    "INSERT INTO users(id) VALUES (1);",
    "DROP TABLE users;",
    "CREATE TABLE x(id INT);",
    "ALTER TABLE users ADD COLUMN x INT;",
    "ATTACH DATABASE 'hack.db' AS h;",
    "PRAGMA journal_mode=WAL;",
])
def test_safety_blocks_forbidden_statements(sql):
    """Every DML/DDL/admin statement class is refused."""
    assert not Safety().check(sql).ok


def test_safety_blocks_stacked_delete_after_select():
    """Stacked statements: a DELETE after a SELECT is still refused."""
    assert not Safety().check("SELECT * FROM users; DELETE FROM users;").ok


def test_safety_blocks_stacked_delete_with_spaces():
    """Whitespace and newlines around the statement separator change nothing."""
    assert not Safety().check("SELECT * FROM users ; \n DELETE users;").ok


def test_safety_blocks_delete_inside_cte():
    """A writable CTE body (DELETE inside WITH) is refused."""
    checker = Safety()
    sql = """
    WITH bad AS (DELETE FROM users)
    SELECT * FROM users;
    """
    assert not checker.check(sql).ok


@pytest.mark.parametrize("sql", [
    "/*D*/ROP TABLE users;",
    "PR/*x*/AGMA journal_mode=WAL;",
    "AL/* comment */TER TABLE x ADD COLUMN y INT;",
])
def test_safety_blocks_comment_obfuscation(sql):
    """Keywords split by inline comments must still be detected."""
    assert not Safety().check(sql).ok


@pytest.mark.parametrize("sql", [
    "pragma journal_mode=WAL;",  # lower-case
    " PRAGMA user_version = 5 ; ",
    "\nATTACH DATABASE 'hack.db' AS h;",
])
def test_safety_blocks_forbidden_case_and_spacing(sql):
    """Matching is case-insensitive and tolerant of surrounding whitespace."""
    assert not Safety().check(sql).ok


def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
    """A trailing comment after ';' is fine; a second real statement is not."""
    checker = Safety()
    sql = "SELECT 1; -- now do something bad\n"
    sql_bad = "SELECT 1; /* spacer */ DROP TABLE x;"
    assert checker.check(sql).ok
    assert not checker.check(sql_bad).ok
tests/test_stage_types.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.types import StageResult, StageTrace
2
+
3
def test_error_response():
    """A failed StageResult keeps its error list verbatim."""
    result = StageResult(ok=False, error=["Syntax error"])
    assert not result.ok
    assert result.error == ["Syntax error"]


def test_trace_dataclass_structure():
    """StageTrace stores stage name, float duration and token counts."""
    trace = StageTrace(stage="planner", duration_ms=12.5, token_in=10, token_out=20)
    assert trace.stage == "planner"
    assert isinstance(trace.duration_ms, float)
    assert trace.token_out == 20


def test_stage_result_defaults():
    """`data` and `error` default to None on a bare success result."""
    result = StageResult(ok=True)
    assert result.ok
    assert result.data is None
    assert result.error is None
ui/benchmark_app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import pandas as pd
import streamlit as st
import plotly.express as px
from pathlib import Path

st.set_page_config(page_title="NL2SQL Benchmark Dashboard", layout="wide")

st.title("📊 NL2SQL Copilot – Benchmark Dashboard")

# 1. Load results: each *.jsonl file holds one JSON object per line.
result_files = list(Path("benchmarks/results").glob("*.jsonl"))
if not result_files:
    st.warning("No benchmark result files found in benchmarks/results/")
    st.stop()

selected_file = st.selectbox("Select benchmark file", result_files)
# Use a context manager so the handle is closed deterministically (the original
# `open(file)` inside a comprehension was never closed), and skip blank lines
# so a trailing newline in the .jsonl doesn't crash json.loads.
with open(selected_file, encoding="utf-8") as fh:
    rows = [json.loads(line) for line in fh if line.strip()]
df = pd.DataFrame(rows)

# 2. Summary metrics
st.subheader("Aggregate Metrics")
col1, col2, col3, col4 = st.columns(4)
col1.metric("Total Queries", len(df))
col2.metric("Execution Accuracy", f"{df['exec_acc'].mean()*100:.1f}%")
col3.metric("Safety Violations", f"{df['safe_fail'].mean()*100:.1f}%")
col4.metric("Average Latency (ms)", f"{df['latency_ms'].mean():.0f}")

# 3. Latency Distribution
st.subheader("Latency Distribution")
fig1 = px.histogram(df, x="latency_ms", nbins=30, title="Latency Histogram")
st.plotly_chart(fig1, use_container_width=True)

# 4. Cost vs Accuracy
st.subheader("Cost vs Execution Accuracy")
fig2 = px.scatter(df, x="cost_usd", y="exec_acc", color="provider",
                  title="Trade-off: Cost vs Accuracy", hover_data=["query"])
st.plotly_chart(fig2, use_container_width=True)

# 5. Repair Stats (only present when the benchmark recorded repair loops)
if "repair_attempts" in df.columns:
    st.subheader("Repair Attempts")
    repair_counts = df.groupby("repair_attempts").size().reset_index(name="count")
    fig3 = px.bar(repair_counts, x="repair_attempts", y="count",
                  title="Number of Repair Attempts per Query")
    st.plotly_chart(fig3, use_container_width=True)