from config import ( LLM_MODEL, LLM_TEMPERATURE, FORBIDDEN_KEYWORDS, FORBIDDEN_TABLES ) import os import sqlite3 import json import re from typing import Optional, Tuple, List import gradio as gr import sqlglot from sqlglot import exp from langchain_openai import ChatOpenAI from langchain_community.utilities import SQLDatabase from langchain.chains import create_sql_query_chain from langchain.prompts import ChatPromptTemplate def get_readonly_sqlite_url(db_path: str) -> str: return f"file:{db_path}?mode=ro&uri=true" def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str: uri = get_readonly_sqlite_url(db_path) with sqlite3.connect(uri, uri=True, timeout=3) as conn: conn.row_factory = sqlite3.Row cur = conn.cursor() cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;") tables = [r["name"] for r in cur.fetchall()] lines = [] for t in tables: # skip SQLite internals if t in FORBIDDEN_TABLES: continue cur.execute(f"PRAGMA table_info({t});") cols = cur.fetchall() col_line = ", ".join([f"{c['name']}:{c['type']}" for c in cols]) lines.append(f"- {t} ({col_line})") if limit_per_table > 0: try: cur.execute(f"SELECT * FROM {t} LIMIT {limit_per_table};") sample = cur.fetchall() if sample: lines.append(f" sample rows: {len(sample)}") except Exception: pass if not lines: return "(no user tables found)" return "\n".join(lines) def validate_sql_safe(sql: str) -> Tuple[bool, str]: if sql.count(";") > 0: if sql.strip().endswith(";"): if sql.strip()[:-1].count(";") > 0: return False, "Multiple statements are not allowed." else: return False, "Multiple statements are not allowed." upper = re.sub(r"\s+", " ", sql).strip() for kw in FORBIDDEN_KEYWORDS: if re.search(rf"\b{kw}\b", upper): return False, f"Keyword '{kw}' is not allowed." try: parsed = sqlglot.parse(sql, read='sqlite') except Exception as e: return False, f"SQL parse error: {e}" if not parsed or len(parsed) != 1: return False, "Exactly one SQL statement is allowed." stmt = parsed[0] if not isinstance(stmt, exp.Select): return False, "Only SELECT statements are allowed." for table in stmt.find_all(exp.Table): table_name = table.name.lower() if table.name else "" if table_name in FORBIDDEN_TABLES: return False, f"Access to {table_name} is not allowed." return True, "OK" def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[list[str], List[List]]: uri = get_readonly_sqlite_url(db_path) if not re.search(r"\bLIMIT\b", sql, re.IGNORECASE): sql = f"{sql.rstrip(';')} LIMIT {max_rows}" with sqlite3.connect(uri, uri=True, timeout=timeout) as conn: conn.row_factory = sqlite3.Row cur = conn.cursor() cur.execute(sql) rows = cur.fetchall() if rows: cols = rows[0].keys() data = [list(r) for r in rows] return list(cols), data else: return [], [] custom_prompt = ChatPromptTemplate.from_template(""" Given the following question, return ONLY a valid SQL query in JSON form. Question: {input} Database schema: {table_info} You may sample/preview at most {top_k} rows if you need examples. Respond in this exact JSON format: {{ "sql": "" }} """) def make_sql_chain(sql_db: SQLDatabase): llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE) chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20) return chain def on_upload_database(db_file, state): if db_file is None: return state, "No file provided.", "(no schema)" path = db_file.name sql_db = SQLDatabase.from_uri(f"sqlite:///{path}") schema_text = get_schema_preview(path, limit_per_table=0) chain = make_sql_chain(sql_db) new_state = { "db_path": path, "sql_db": sql_db, "schema_text": schema_text, "chain": chain, } return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text def extract_sql_safe(output_text: str) -> str: try: obj = json.loads(output_text) if isinstance(obj, dict) and "sql" in obj: return obj["sql"].strip() except Exception: pass m = re.search(r"```sql\s*(.*?)\s*```", output_text, re.DOTALL | re.IGNORECASE) if m: return m.group(1).strip() return output_text.strip() def on_generate_query(question , max_rows, state): if not state or not state.get("db_path") or not state.get("chain"): return "Please upload a database first.", "", "" if not question or not question.strip(): return "Please enter a question.", "", "" try: generated_sql = state["chain"].invoke({"question": question}) sql = extract_sql_safe(str(generated_sql)) ok, msg = validate_sql_safe(sql) if not ok: return f"Blocked SQL: {msg}", sql, "" cols, rows = execute_select(state["db_path"], sql, max_rows=max_rows) if not cols: return f"No rows returned.", sql, "[]" sample = [dict(zip(cols, r)) for r in rows[:50]] return f"Returned {len(rows)} row(s). Showing up to 50.", sql, json.dumps(sample, indent=2) except Exception as e: return f"Error: {e}", "", "" with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo: gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)") gr.Markdown( "Upload a **SQLite** file, ask a question in natural language, " "and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, " "and (4) show you the results." ) state = gr.State({"db_path": None, "sql_db": None, "schema_text": "", "chain": None}) with gr.Row(): db_file = gr.File(label="Upload SQlite Database", file_types=[".sqlite", ".db"]) upload_status = gr.Textbox(label="upload Status", interactive=False) schema_box = gr.Accordion("Database schema (preview)", open=False) with schema_box: schema_md = gr.Markdown("(no schema)") gr.Markdown("---") with gr.Row(): question = gr.Textbox(label="Your question", placeholder="e.g., Top 10 tracks by total sales") with gr.Row(): max_row= gr.Slider(10, 5000, value=1000, step=10, label="Max rows") with gr.Row(): run_btn = gr.Button("Generate & Run SQL", variant="primary") with gr.Row(): status_out = gr.Textbox(label="Status") with gr.Row(): sql_out = gr.Code(label="Generated SQL (validated)") with gr.Row(): result_out = gr.Code(label="Result (JSON sample)") db_file.change( fn=on_upload_database, inputs=[db_file, state], outputs=[state, upload_status, schema_md], ) run_btn.click( fn=on_generate_query, inputs=[question, max_row, state], outputs=[status_out, sql_out, result_out], ) if __name__ == "__main__": demo.launch()