File size: 7,451 Bytes
a82f275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
from config import (
    LLM_MODEL,
    LLM_TEMPERATURE,
    FORBIDDEN_KEYWORDS,
    FORBIDDEN_TABLES
)
import os
import sqlite3
import json
import re
from typing import Optional, Tuple, List

import gradio as gr
import sqlglot
from sqlglot import exp

from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain.prompts import ChatPromptTemplate


def get_readonly_sqlite_url(db_path: str) -> str:
    return f"file:{db_path}?mode=ro&uri=true"

def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str:
    uri = get_readonly_sqlite_url(db_path)
    with sqlite3.connect(uri, uri=True, timeout=3) as conn:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
        tables = [r["name"] for r in cur.fetchall()]
        lines = []
        for t in tables:
            # skip SQLite internals
            if t in FORBIDDEN_TABLES:
                continue
            cur.execute(f"PRAGMA table_info({t});")
            cols = cur.fetchall()
            col_line = ", ".join([f"{c['name']}:{c['type']}" for c in cols])
            lines.append(f"- {t} ({col_line})")
            if limit_per_table > 0:
                try:
                    cur.execute(f"SELECT * FROM {t} LIMIT {limit_per_table};")
                    sample = cur.fetchall()
                    if sample:
                        lines.append(f"  sample rows: {len(sample)}")
                except Exception:
                    pass
        if not lines:
            return "(no user tables found)"
        return "\n".join(lines)


def validate_sql_safe(sql: str) -> Tuple[bool, str]:
    if sql.count(";") > 0:
        if sql.strip().endswith(";"):
            if sql.strip()[:-1].count(";") > 0:
                return False, "Multiple statements are not allowed."
        else:
            return False, "Multiple statements are not allowed."

    upper = re.sub(r"\s+", " ", sql).strip()
    for kw in FORBIDDEN_KEYWORDS:
        if re.search(rf"\b{kw}\b", upper):
            return False, f"Keyword '{kw}' is not allowed."

    try:
        parsed = sqlglot.parse(sql, read='sqlite')
    except Exception as e:
        return False, f"SQL parse error: {e}"

    if not parsed or len(parsed) != 1:
        return False, "Exactly one SQL statement is allowed."

    stmt = parsed[0]
    if not isinstance(stmt, exp.Select):
        return False, "Only SELECT statements are allowed."

    for table in stmt.find_all(exp.Table):
        table_name = table.name.lower() if table.name else ""
        if table_name in FORBIDDEN_TABLES:
            return False, f"Access to {table_name} is not allowed."

    return True, "OK"

def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[list[str], List[List]]:
    uri = get_readonly_sqlite_url(db_path)
    if not re.search(r"\bLIMIT\b", sql, re.IGNORECASE):
        sql = f"{sql.rstrip(';')} LIMIT {max_rows}"

    with sqlite3.connect(uri, uri=True, timeout=timeout) as conn:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        if rows:
            cols = rows[0].keys()
            data = [list(r) for r in rows]
            return list(cols), data
        else:
            return [], []



custom_prompt = ChatPromptTemplate.from_template("""
Given the following question, return ONLY a valid SQL query in JSON form.

Question: {input}
Database schema: {table_info}

You may sample/preview at most {top_k} rows if you need examples.

Respond in this exact JSON format:
{{
  "sql": "<SQL_QUERY_HERE>"
}}
""")


def make_sql_chain(sql_db: SQLDatabase):
    llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
    chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
    return chain


def on_upload_database(db_file, state):
    if db_file is None:
        return state, "No file provided.", "(no schema)"
    path = db_file.name

    sql_db = SQLDatabase.from_uri(f"sqlite:///{path}")

    schema_text = get_schema_preview(path, limit_per_table=0)

    chain = make_sql_chain(sql_db)

    new_state = {
        "db_path": path,
        "sql_db": sql_db,
        "schema_text": schema_text,
        "chain": chain,
    }
    return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text

def extract_sql_safe(output_text: str) -> str:
    try:
        obj = json.loads(output_text)
        if isinstance(obj, dict) and "sql" in obj:
            return obj["sql"].strip()
    except Exception:
        pass
    m = re.search(r"```sql\s*(.*?)\s*```", output_text, re.DOTALL | re.IGNORECASE)
    if m:
        return m.group(1).strip()
    return output_text.strip()

def on_generate_query(question , max_rows, state):
    if not state or not state.get("db_path") or not state.get("chain"):
        return "Please upload a database first.", "", ""
    if not question or not question.strip():
        return "Please enter a question.", "", ""

    try:
        generated_sql = state["chain"].invoke({"question": question})

        sql = extract_sql_safe(str(generated_sql))

        ok, msg = validate_sql_safe(sql)
        if not ok:
            return f"Blocked SQL: {msg}", sql, ""

        cols, rows = execute_select(state["db_path"], sql, max_rows=max_rows)
        if not cols:
            return f"No rows returned.", sql, "[]"

        sample = [dict(zip(cols, r)) for r in rows[:50]]
        return f"Returned {len(rows)} row(s). Showing up to 50.", sql, json.dumps(sample, indent=2)

    except Exception as e:
        return f"Error: {e}", "", ""


with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo:
    gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)")
    gr.Markdown(
        "Upload a **SQLite** file, ask a question in natural language, "
        "and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, "
        "and (4) show you the results."
    )

    state = gr.State({"db_path": None, "sql_db": None, "schema_text": "", "chain": None})

    with gr.Row():
        db_file = gr.File(label="Upload SQlite Database", file_types=[".sqlite", ".db"])
        upload_status = gr.Textbox(label="upload Status", interactive=False)

    schema_box = gr.Accordion("Database schema (preview)", open=False)
    with schema_box:
        schema_md = gr.Markdown("(no schema)")

    gr.Markdown("---")

    with gr.Row():
        question = gr.Textbox(label="Your question", placeholder="e.g., Top 10 tracks by total sales")
    with gr.Row():
        max_row= gr.Slider(10, 5000, value=1000, step=10, label="Max rows")

    with gr.Row():
        run_btn = gr.Button("Generate & Run SQL", variant="primary")

    with gr.Row():
        status_out = gr.Textbox(label="Status")
    with gr.Row():
        sql_out = gr.Code(label="Generated SQL (validated)")
    with gr.Row():
        result_out = gr.Code(label="Result (JSON sample)")

    db_file.change(
        fn=on_upload_database,
        inputs=[db_file, state],
        outputs=[state, upload_status, schema_md],
    )

    run_btn.click(
        fn=on_generate_query,
        inputs=[question, max_row, state],
        outputs=[status_out, sql_out, result_out],
    )



if __name__ == "__main__":
    demo.launch()