nl2sql-copilot / app.py
Melika Kheirieh
Add first benchmark
e207f41
raw
history blame
7.53 kB
from config import (
LLM_MODEL,
LLM_TEMPERATURE,
FORBIDDEN_KEYWORDS,
FORBIDDEN_TABLES
)
import os
import sqlite3
import json
import re
from typing import Optional, Tuple, List
import gradio as gr
import sqlglot
from sqlglot import exp
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain.prompts import ChatPromptTemplate
def get_readonly_sqlite_url(db_path: str) -> str:
return f"file:{db_path}?mode=ro&uri=true"
def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str:
uri = get_readonly_sqlite_url(db_path)
with sqlite3.connect(uri, uri=True, timeout=3) as conn:
conn.row_factory = sqlite3.Row
cur = conn.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
tables = [r["name"] for r in cur.fetchall()]
lines = []
for t in tables:
# skip SQLite internals
if t in FORBIDDEN_TABLES:
continue
cur.execute(f"PRAGMA table_info({t});")
cols = cur.fetchall()
col_line = ", ".join([f"{c['name']}:{c['type']}" for c in cols])
lines.append(f"- {t} ({col_line})")
if limit_per_table > 0:
try:
cur.execute(f"SELECT * FROM {t} LIMIT {limit_per_table};")
sample = cur.fetchall()
if sample:
lines.append(f" sample rows: {len(sample)}")
except Exception:
pass
if not lines:
return "(no user tables found)"
return "\n".join(lines)
def validate_sql_safe(sql: str) -> Tuple[bool, str]:
if sql.count(";") > 0:
if sql.strip().endswith(";"):
if sql.strip()[:-1].count(";") > 0:
return False, "Multiple statements are not allowed."
else:
return False, "Multiple statements are not allowed."
upper = re.sub(r"\s+", " ", sql).strip()
for kw in FORBIDDEN_KEYWORDS:
if re.search(rf"\b{kw}\b", upper):
return False, f"Keyword '{kw}' is not allowed."
try:
parsed = sqlglot.parse(sql, read='sqlite')
except Exception as e:
return False, f"SQL parse error: {e}"
if not parsed or len(parsed) != 1:
return False, "Exactly one SQL statement is allowed."
stmt = parsed[0]
if not isinstance(stmt, exp.Select):
return False, "Only SELECT statements are allowed."
for table in stmt.find_all(exp.Table):
table_name = table.name.lower() if table.name else ""
if table_name in FORBIDDEN_TABLES:
return False, f"Access to {table_name} is not allowed."
return True, "OK"
def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[list[str], List[List]]:
uri = get_readonly_sqlite_url(db_path)
if not re.search(r"\bLIMIT\b", sql, re.IGNORECASE):
sql = f"{sql.rstrip(';')} LIMIT {max_rows}"
with sqlite3.connect(uri, uri=True, timeout=timeout) as conn:
conn.row_factory = sqlite3.Row
cur = conn.cursor()
cur.execute(sql)
rows = cur.fetchall()
if rows:
cols = rows[0].keys()
data = [list(r) for r in rows]
return list(cols), data
else:
return [], []
custom_prompt = ChatPromptTemplate.from_template("""
Given the following question, return ONLY a valid SQL query in JSON form.
Question: {input}
Database schema: {table_info}
You may sample/preview at most {top_k} rows if you need examples.
Respond in this exact JSON format:
{{
"sql": "<SQL_QUERY_HERE>"
}}
""")
def make_sql_chain(sql_db: SQLDatabase):
assert hasattr(sql_db, "get_table_info"), "Expected LangChain SQLDatabase"
llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
return chain
def on_upload_database(db_file, state):
if db_file is None:
return state, "No file provided.", "(no schema)"
path = db_file.name
sql_db = SQLDatabase.from_uri(f"sqlite:///{path}")
schema_text = get_schema_preview(path, limit_per_table=0)
chain = make_sql_chain(sql_db)
new_state = {
"db_path": path,
"sql_db": sql_db,
"schema_text": schema_text,
"chain": chain,
}
return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text
def extract_sql_safe(output_text: str) -> str:
try:
obj = json.loads(output_text)
if isinstance(obj, dict) and "sql" in obj:
return obj["sql"].strip()
except Exception:
pass
m = re.search(r"```sql\s*(.*?)\s*```", output_text, re.DOTALL | re.IGNORECASE)
if m:
return m.group(1).strip()
return output_text.strip()
def on_generate_query(question , max_rows, state):
if not state or not state.get("db_path") or not state.get("chain"):
return "Please upload a database first.", "", ""
if not question or not question.strip():
return "Please enter a question.", "", ""
try:
generated_sql = state["chain"].invoke({"question": question})
sql = extract_sql_safe(str(generated_sql))
ok, msg = validate_sql_safe(sql)
if not ok:
return f"Blocked SQL: {msg}", sql, ""
cols, rows = execute_select(state["db_path"], sql, max_rows=max_rows)
if not cols:
return f"No rows returned.", sql, "[]"
sample = [dict(zip(cols, r)) for r in rows[:50]]
return f"Returned {len(rows)} row(s). Showing up to 50.", sql, json.dumps(sample, indent=2)
except Exception as e:
return f"Error: {e}", "", ""
with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo:
gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)")
gr.Markdown(
"Upload a **SQLite** file, ask a question in natural language, "
"and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, "
"and (4) show you the results."
)
state = gr.State({"db_path": None, "sql_db": None, "schema_text": "", "chain": None})
with gr.Row():
db_file = gr.File(label="Upload SQlite Database", file_types=[".sqlite", ".db"])
upload_status = gr.Textbox(label="upload Status", interactive=False)
schema_box = gr.Accordion("Database schema (preview)", open=False)
with schema_box:
schema_md = gr.Markdown("(no schema)")
gr.Markdown("---")
with gr.Row():
question = gr.Textbox(label="Your question", placeholder="e.g., Top 10 tracks by total sales")
with gr.Row():
max_row= gr.Slider(10, 5000, value=1000, step=10, label="Max rows")
with gr.Row():
run_btn = gr.Button("Generate & Run SQL", variant="primary")
with gr.Row():
status_out = gr.Textbox(label="Status")
with gr.Row():
sql_out = gr.Code(label="Generated SQL (validated)")
with gr.Row():
result_out = gr.Code(label="Result (JSON sample)")
db_file.change(
fn=on_upload_database,
inputs=[db_file, state],
outputs=[state, upload_status, schema_md],
)
run_btn.click(
fn=on_generate_query,
inputs=[question, max_row, state],
outputs=[status_out, sql_out, result_out],
)
if __name__ == "__main__":
demo.launch()