| from config import ( |
| LLM_MODEL, |
| LLM_TEMPERATURE, |
| FORBIDDEN_KEYWORDS, |
| FORBIDDEN_TABLES |
| ) |
| import os |
| import sqlite3 |
| import json |
| import re |
| from typing import Optional, Tuple, List |
|
|
| import gradio as gr |
| import sqlglot |
| from sqlglot import exp |
|
|
| from langchain_openai import ChatOpenAI |
| from langchain_community.utilities import SQLDatabase |
| from langchain.chains import create_sql_query_chain |
| from langchain.prompts import ChatPromptTemplate |
|
|
|
|
def get_readonly_sqlite_url(db_path: str) -> str:
    """Build a read-only SQLite ``file:`` URI for ``db_path``.

    The path is percent-encoded so that characters with a special meaning in
    a URI (``?``, ``#``, ``%``, spaces) stay part of the filename instead of
    cutting the path short or leaking into the query string.

    Args:
        db_path: Filesystem path of the SQLite database file.

    Returns:
        A URI of the form ``file:<encoded path>?mode=ro&uri=true``, meant to
        be opened with ``sqlite3.connect(..., uri=True)``.
    """
    # Function-scope import keeps this helper self-contained.
    from urllib.parse import quote

    # Path separators and the drive-letter colon are left untouched so plain
    # paths produce exactly the same URI as before.
    encoded_path = quote(db_path, safe="/\\:")
    return f"file:{encoded_path}?mode=ro&uri=true"
|
|
def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str:
    """Return a plain-text summary of the tables in a SQLite database.

    Each table that is not listed in ``FORBIDDEN_TABLES`` contributes one
    line of the form ``- table (col:type, ...)``. When ``limit_per_table``
    is positive, up to that many rows are read from each table and a
    ``sample rows: N`` line is added if any were found.

    Args:
        db_path: Filesystem path of the SQLite database file.
        limit_per_table: Maximum rows to sample per table; ``0`` (the
            default) skips sampling.

    Returns:
        The summary, one line per entry, or ``"(no user tables found)"``
        when nothing is listed.

    Raises:
        sqlite3.Error: If the database cannot be opened or its table list or
            column info cannot be read. Sampling failures are skipped.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import closing

    def quote_identifier(name: str) -> str:
        # Table names come from the uploaded file, so they are quoted before
        # being placed into SQL text (embedded double quotes are doubled).
        return '"' + name.replace('"', '""') + '"'

    # Compared case-insensitively, matching the lower-cased comparison that
    # validate_sql_safe applies to table names.
    forbidden = {str(name).lower() for name in FORBIDDEN_TABLES}

    uri = get_readonly_sqlite_url(db_path)
    lines = []
    # sqlite3's own context manager only ends the transaction; closing()
    # makes sure the connection itself is released.
    with closing(sqlite3.connect(uri, uri=True, timeout=3)) as conn:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
        tables = [r["name"] for r in cur.fetchall()]
        for t in tables:
            if t.lower() in forbidden:
                continue
            ident = quote_identifier(t)
            cur.execute(f"PRAGMA table_info({ident});")
            cols = cur.fetchall()
            col_line = ", ".join(f"{c['name']}:{c['type']}" for c in cols)
            lines.append(f"- {t} ({col_line})")
            if limit_per_table > 0:
                try:
                    cur.execute(f"SELECT * FROM {ident} LIMIT ?;", (limit_per_table,))
                    sample = cur.fetchall()
                except sqlite3.Error:
                    # Sampling is best-effort: an unreadable table keeps its
                    # column line and simply gets no sample line.
                    continue
                if sample:
                    lines.append(f" sample rows: {len(sample)}")
    if not lines:
        return "(no user tables found)"
    return "\n".join(lines)
|
|
|
|
def validate_sql_safe(sql: str) -> Tuple[bool, str]:
    """Check that ``sql`` is a single, permitted SELECT statement.

    The checks run in this order:

    1. No semicolon other than one trailing semicolon.
    2. None of ``FORBIDDEN_KEYWORDS`` occurs as a whole word, in any letter
       case. This is a plain text scan, so a keyword inside a string
       literal or used as an identifier is rejected too.
    3. The text parses (SQLite dialect) into exactly one statement.
    4. That statement is a ``SELECT``.
    5. No referenced table is listed in ``FORBIDDEN_TABLES``
       (case-insensitive).

    Args:
        sql: The SQL text to validate.

    Returns:
        ``(True, "OK")`` when every check passes, otherwise ``(False,
        reason)`` with a human-readable reason.
    """
    # One trailing semicolon is tolerated; any other one means a second
    # statement (or at least something this validator will not reason about).
    statement = sql.strip()
    body = statement[:-1] if statement.endswith(";") else statement
    if ";" in body:
        return False, "Multiple statements are not allowed."

    # Case-insensitive whole-word scan; keywords are escaped so they are
    # matched literally rather than interpreted as regex syntax.
    normalized = re.sub(r"\s+", " ", sql).strip()
    for kw in FORBIDDEN_KEYWORDS:
        if re.search(rf"\b{re.escape(str(kw))}\b", normalized, re.IGNORECASE):
            return False, f"Keyword '{kw}' is not allowed."

    try:
        parsed = sqlglot.parse(sql, read="sqlite")
    except Exception as e:
        return False, f"SQL parse error: {e}"

    # Drop empty entries so that only real statements are counted.
    statements = [s for s in (parsed or []) if s is not None]
    if len(statements) != 1:
        return False, "Exactly one SQL statement is allowed."

    stmt = statements[0]
    if not isinstance(stmt, exp.Select):
        return False, "Only SELECT statements are allowed."

    forbidden_tables = {str(name).lower() for name in FORBIDDEN_TABLES}
    for table in stmt.find_all(exp.Table):
        table_name = (table.name or "").lower()
        if table_name in forbidden_tables:
            return False, f"Access to {table_name} is not allowed."

    return True, "OK"
|
|
def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[list[str], List[List]]:
    """Run ``sql`` against the database in read-only mode and return its rows.

    The row cap is enforced while fetching instead of by appending a
    ``LIMIT`` clause to the text. It therefore also applies when the query
    carries its own (possibly larger) ``LIMIT``, and the SQL is executed
    exactly as given.

    Args:
        db_path: Filesystem path of the SQLite database file.
        sql: The statement to execute; callers are expected to have
            validated it already.
        max_rows: Maximum number of rows to return. Converted with ``int``;
            values of zero or less return no rows.
        timeout: Passed to ``sqlite3.connect`` as its ``timeout`` argument
            (the wait for a locked database, not a cap on query run time).

    Returns:
        ``(column_names, rows)`` with each row as a list. Both are empty
        lists when the query yields no rows.

    Raises:
        sqlite3.Error: If the database cannot be opened or the statement
            fails.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import closing

    row_cap = max(int(max_rows), 0)
    uri = get_readonly_sqlite_url(db_path)

    # sqlite3's own context manager only ends the transaction; closing()
    # makes sure the connection itself is released.
    with closing(sqlite3.connect(uri, uri=True, timeout=timeout)) as conn:
        cur = conn.cursor()
        cur.execute(sql)
        # A size of 0 is not passed to fetchmany because it is not a
        # reliable way to ask for "no rows".
        rows = cur.fetchmany(row_cap) if row_cap else []
        if not rows:
            return [], []
        cols = [column[0] for column in cur.description]
        data = [list(r) for r in rows]
        return cols, data
|
|
|
|
|
|
# Text of the prompt given to the SQL query chain. The single-brace
# placeholders ({input}, {table_info}, {top_k}) are filled in when the prompt
# is rendered; the doubled braces are literal braces in the JSON example.
_SQL_PROMPT_TEXT = """
Given the following question, return ONLY a valid SQL query in JSON form.

Question: {input}
Database schema: {table_info}

You may sample/preview at most {top_k} rows if you need examples.

Respond in this exact JSON format:
{{
"sql": "<SQL_QUERY_HERE>"
}}
"""

# Prompt object passed to create_sql_query_chain in make_sql_chain.
custom_prompt = ChatPromptTemplate.from_template(_SQL_PROMPT_TEXT)
|
|
|
|
def make_sql_chain(sql_db: SQLDatabase):
    """Create the question-to-SQL chain for ``sql_db``.

    A ``ChatOpenAI`` model configured from ``LLM_MODEL`` and
    ``LLM_TEMPERATURE`` is combined with ``custom_prompt`` through
    ``create_sql_query_chain``.

    Args:
        sql_db: The LangChain database wrapper the chain reads from.

    Returns:
        The chain object produced by ``create_sql_query_chain``.
    """
    return create_sql_query_chain(
        ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE),
        sql_db,
        prompt=custom_prompt,
        # NOTE(review): presumably the value behind the prompt's {top_k}
        # placeholder; confirm against the LangChain documentation.
        k=20,
    )
|
|
|
|
def on_upload_database(db_file, state):
    """Handle a database upload and build the per-session state.

    Args:
        db_file: The uploaded file as delivered by the file component:
            ``None`` when nothing is selected, otherwise either a path
            string or an object exposing the path as ``.name``.
        state: The current session state; returned unchanged when there is
            nothing to load or loading fails.

    Returns:
        ``(state, status_message, schema_text)``. On success ``state`` is a
        new dict with the keys ``db_path``, ``sql_db``, ``schema_text`` and
        ``chain``.
    """
    if db_file is None:
        return state, "No file provided.", "(no schema)"

    # Accept both a plain path string and an object carrying the path in
    # ``.name``.
    path = getattr(db_file, "name", db_file)

    try:
        sql_db = SQLDatabase.from_uri(f"sqlite:///{path}")
        schema_text = get_schema_preview(path, limit_per_table=0)
        chain = make_sql_chain(sql_db)
    except Exception as e:
        # UI boundary: report the failure in the status box and keep the
        # previous state instead of letting the exception escape the handler.
        return state, f"Failed to load database: {e}", "(no schema)"

    new_state = {
        "db_path": path,
        "sql_db": sql_db,
        "schema_text": schema_text,
        "chain": chain,
    }
    return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text
|
|
def extract_sql_safe(output_text: str) -> str:
    """Pull the SQL statement out of a model response.

    The following shapes are recognised, in this order of preference:

    1. A JSON object with a string ``"sql"`` field. The object may be the
       whole response, sit inside a Markdown code fence (for example
       ```` ```json ````), or be surrounded by other text.
    2. A Markdown code fence; one tagged ``sql`` wins over any other.
    3. Anything else is returned as-is, stripped of surrounding whitespace.

    Args:
        output_text: The raw text produced by the model.

    Returns:
        The extracted SQL, stripped of surrounding whitespace.
    """
    text = output_text.strip()

    # Fenced blocks with an optional sql/json/sqlite tag; group 1 is the
    # tag, group 2 the content.
    fences = list(
        re.finditer(
            r"```(?:(sql|json|sqlite)\b)?\s*(.*?)\s*```",
            text,
            re.DOTALL | re.IGNORECASE,
        )
    )

    # Every place a JSON object might be: the whole text, each fenced
    # block, and the outermost {...} span.
    candidates = [text]
    candidates.extend(m.group(2) for m in fences)
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except ValueError:
            # Not JSON (json.JSONDecodeError is a ValueError); try the next.
            continue
        if isinstance(obj, dict) and isinstance(obj.get("sql"), str):
            return obj["sql"].strip()

    # No usable JSON: fall back to fenced content, preferring a sql fence.
    for m in fences:
        if (m.group(1) or "").lower() == "sql":
            return m.group(2).strip()
    if fences:
        return fences[0].group(2).strip()
    return text
|
|
def on_generate_query(question, max_rows, state):
    """Generate SQL for ``question``, validate it, run it and format the result.

    Args:
        question: The natural-language question typed by the user.
        max_rows: Maximum number of rows to fetch; converted with ``int``
            before use.
        state: Session state dict; must provide ``db_path`` and ``chain``.

    Returns:
        ``(status, sql, result_json)``. ``result_json`` is a JSON array of
        at most 50 row objects, ``"[]"`` when no rows came back, or ``""``
        when nothing was executed. On failure ``status`` starts with
        ``"Error:"`` or ``"Blocked SQL:"`` and ``sql`` holds whatever SQL
        had been extracted by then.
    """
    if not state or not state.get("db_path") or not state.get("chain"):
        return "Please upload a database first.", "", ""
    if not question or not question.strip():
        return "Please enter a question.", "", ""

    def to_jsonable(value):
        # sqlite3 hands BLOB columns back as bytes, which json cannot
        # encode; show a short placeholder instead of failing the request.
        if isinstance(value, (bytes, bytearray)):
            return f"<blob: {len(value)} bytes>"
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # Kept outside the try so the error path can still show the SQL that was
    # generated before the failure.
    sql = ""
    try:
        generated_sql = state["chain"].invoke({"question": question})
        sql = extract_sql_safe(str(generated_sql))

        ok, msg = validate_sql_safe(sql)
        if not ok:
            return f"Blocked SQL: {msg}", sql, ""

        cols, rows = execute_select(state["db_path"], sql, max_rows=int(max_rows))
        if not cols:
            return "No rows returned.", sql, "[]"

        sample = [dict(zip(cols, r)) for r in rows[:50]]
        return (
            f"Returned {len(rows)} row(s). Showing up to 50.",
            sql,
            json.dumps(sample, indent=2, default=to_jsonable),
        )
    except Exception as e:
        # UI boundary: surface the failure in the status box.
        return f"Error: {e}", sql, ""
|
|
|
|
# Gradio UI: upload a SQLite file, ask a question, see the generated SQL and
# a JSON sample of the result.
with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo:
    gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)")
    gr.Markdown(
        "Upload a **SQLite** file, ask a question in natural language, "
        "and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, "
        "and (4) show you the results."
    )

    # Session state; on_upload_database replaces it after a successful upload.
    state = gr.State(
        {
            "db_path": None,
            "sql_db": None,
            "schema_text": "",
            "chain": None,
        }
    )

    # Upload controls.
    with gr.Row():
        db_file = gr.File(
            label="Upload SQlite Database",
            file_types=[".sqlite", ".db"],
        )
        upload_status = gr.Textbox(label="upload Status", interactive=False)

    # Collapsible schema preview, filled in by the upload handler.
    with gr.Accordion("Database schema (preview)", open=False) as schema_box:
        schema_md = gr.Markdown("(no schema)")

    gr.Markdown("---")

    # Query inputs.
    with gr.Row():
        question = gr.Textbox(
            label="Your question",
            placeholder="e.g., Top 10 tracks by total sales",
        )
    with gr.Row():
        max_row = gr.Slider(minimum=10, maximum=5000, value=1000, step=10, label="Max rows")

    with gr.Row():
        run_btn = gr.Button("Generate & Run SQL", variant="primary")

    # Outputs.
    with gr.Row():
        status_out = gr.Textbox(label="Status")
    with gr.Row():
        sql_out = gr.Code(label="Generated SQL (validated)")
    with gr.Row():
        result_out = gr.Code(label="Result (JSON sample)")

    # Event wiring.
    db_file.change(
        on_upload_database,
        inputs=[db_file, state],
        outputs=[state, upload_status, schema_md],
    )
    run_btn.click(
        on_generate_query,
        inputs=[question, max_row, state],
        outputs=[status_out, sql_out, result_out],
    )


# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()