Melika Kheirieh commited on
Commit
a82f275
·
0 Parent(s):

Initial commit: basic Gradio + Langchain SQL copilot prototype

Browse files
.env.example ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---- GAPGPT proxy config ----
2
+ # If you’re using a proxy (e.g., GapGPT, Helicone, LocalAI, etc.),
3
+ # set these two values. Otherwise, leave them blank.
4
+ PROXY_API_KEY="your-proxy-token-here"
5
+ PROXY_BASE_URL="https://api.proxy.app/v1"
6
+
7
+
8
+ # ---- optional direct OpenAI config (for fallback) ----
9
+ # These will be used only if proxy variables are not set.
10
+ #OPENAI_API_KEY="your-openai-key-here"
11
+ #OPENAI_BASE_URL="https://api.openai.com/v1"
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .env
2
+ __pycache__/
3
+ .venv/
4
+ .DS_Store
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/dataSources.xml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="DataSourceManagerImpl" format="xml" multifile-model="true">
4
+ <data-source source="LOCAL" name="Chinook_Sqlite" uuid="4036a8cf-a7c0-4e84-909d-ada6895430c6">
5
+ <driver-ref>sqlite.xerial</driver-ref>
6
+ <synchronize>true</synchronize>
7
+ <jdbc-driver>org.sqlite.JDBC</jdbc-driver>
8
+ <jdbc-url>jdbc:sqlite:$PROJECT_DIR$/db/Chinook_Sqlite.sqlite</jdbc-url>
9
+ <working-dir>$ProjectFileDir$</working-dir>
10
+ </data-source>
11
+ </component>
12
+ </project>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="10">
8
+ <item index="0" class="java.lang.String" itemvalue="tiktoken" />
9
+ <item index="1" class="java.lang.String" itemvalue="openai" />
10
+ <item index="2" class="java.lang.String" itemvalue="langchain-community" />
11
+ <item index="3" class="java.lang.String" itemvalue="langgraph" />
12
+ <item index="4" class="java.lang.String" itemvalue="pydantic" />
13
+ <item index="5" class="java.lang.String" itemvalue="regex" />
14
+ <item index="6" class="java.lang.String" itemvalue="langchain-openai" />
15
+ <item index="7" class="java.lang.String" itemvalue="langchain" />
16
+ <item index="8" class="java.lang.String" itemvalue="lxml" />
17
+ <item index="9" class="java.lang.String" itemvalue="html5lib" />
18
+ </list>
19
+ </value>
20
+ </option>
21
+ </inspection_tool>
22
+ </profile>
23
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.13 (ldoce5viewer-master)" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="LLM" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Text-to-SQL.iml" filepath="$PROJECT_DIR$/.idea/Text-to-SQL.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/nl2sql-copilot-prototype.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="LLM" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from config import (
2
+ LLM_MODEL,
3
+ LLM_TEMPERATURE,
4
+ FORBIDDEN_KEYWORDS,
5
+ FORBIDDEN_TABLES
6
+ )
7
+ import os
8
+ import sqlite3
9
+ import json
10
+ import re
11
+ from typing import Optional, Tuple, List
12
+
13
+ import gradio as gr
14
+ import sqlglot
15
+ from sqlglot import exp
16
+
17
+ from langchain_openai import ChatOpenAI
18
+ from langchain_community.utilities import SQLDatabase
19
+ from langchain.chains import create_sql_query_chain
20
+ from langchain.prompts import ChatPromptTemplate
21
+
22
+
23
+ def get_readonly_sqlite_url(db_path: str) -> str:
24
+ return f"file:{db_path}?mode=ro&uri=true"
25
+
26
+ def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str:
27
+ uri = get_readonly_sqlite_url(db_path)
28
+ with sqlite3.connect(uri, uri=True, timeout=3) as conn:
29
+ conn.row_factory = sqlite3.Row
30
+ cur = conn.cursor()
31
+ cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
32
+ tables = [r["name"] for r in cur.fetchall()]
33
+ lines = []
34
+ for t in tables:
35
+ # skip SQLite internals
36
+ if t in FORBIDDEN_TABLES:
37
+ continue
38
+ cur.execute(f"PRAGMA table_info({t});")
39
+ cols = cur.fetchall()
40
+ col_line = ", ".join([f"{c['name']}:{c['type']}" for c in cols])
41
+ lines.append(f"- {t} ({col_line})")
42
+ if limit_per_table > 0:
43
+ try:
44
+ cur.execute(f"SELECT * FROM {t} LIMIT {limit_per_table};")
45
+ sample = cur.fetchall()
46
+ if sample:
47
+ lines.append(f" sample rows: {len(sample)}")
48
+ except Exception:
49
+ pass
50
+ if not lines:
51
+ return "(no user tables found)"
52
+ return "\n".join(lines)
53
+
54
+
55
+ def validate_sql_safe(sql: str) -> Tuple[bool, str]:
56
+ if sql.count(";") > 0:
57
+ if sql.strip().endswith(";"):
58
+ if sql.strip()[:-1].count(";") > 0:
59
+ return False, "Multiple statements are not allowed."
60
+ else:
61
+ return False, "Multiple statements are not allowed."
62
+
63
+ upper = re.sub(r"\s+", " ", sql).strip()
64
+ for kw in FORBIDDEN_KEYWORDS:
65
+ if re.search(rf"\b{kw}\b", upper):
66
+ return False, f"Keyword '{kw}' is not allowed."
67
+
68
+ try:
69
+ parsed = sqlglot.parse(sql, read='sqlite')
70
+ except Exception as e:
71
+ return False, f"SQL parse error: {e}"
72
+
73
+ if not parsed or len(parsed) != 1:
74
+ return False, "Exactly one SQL statement is allowed."
75
+
76
+ stmt = parsed[0]
77
+ if not isinstance(stmt, exp.Select):
78
+ return False, "Only SELECT statements are allowed."
79
+
80
+ for table in stmt.find_all(exp.Table):
81
+ table_name = table.name.lower() if table.name else ""
82
+ if table_name in FORBIDDEN_TABLES:
83
+ return False, f"Access to {table_name} is not allowed."
84
+
85
+ return True, "OK"
86
+
87
+ def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[list[str], List[List]]:
88
+ uri = get_readonly_sqlite_url(db_path)
89
+ if not re.search(r"\bLIMIT\b", sql, re.IGNORECASE):
90
+ sql = f"{sql.rstrip(';')} LIMIT {max_rows}"
91
+
92
+ with sqlite3.connect(uri, uri=True, timeout=timeout) as conn:
93
+ conn.row_factory = sqlite3.Row
94
+ cur = conn.cursor()
95
+ cur.execute(sql)
96
+ rows = cur.fetchall()
97
+ if rows:
98
+ cols = rows[0].keys()
99
+ data = [list(r) for r in rows]
100
+ return list(cols), data
101
+ else:
102
+ return [], []
103
+
104
+
105
+
106
+ custom_prompt = ChatPromptTemplate.from_template("""
107
+ Given the following question, return ONLY a valid SQL query in JSON form.
108
+
109
+ Question: {input}
110
+ Database schema: {table_info}
111
+
112
+ You may sample/preview at most {top_k} rows if you need examples.
113
+
114
+ Respond in this exact JSON format:
115
+ {{
116
+ "sql": "<SQL_QUERY_HERE>"
117
+ }}
118
+ """)
119
+
120
+
121
+ def make_sql_chain(sql_db: SQLDatabase):
122
+ llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
123
+ chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
124
+ return chain
125
+
126
+
127
+ def on_upload_database(db_file, state):
128
+ if db_file is None:
129
+ return state, "No file provided.", "(no schema)"
130
+ path = db_file.name
131
+
132
+ sql_db = SQLDatabase.from_uri(f"sqlite:///{path}")
133
+
134
+ schema_text = get_schema_preview(path, limit_per_table=0)
135
+
136
+ chain = make_sql_chain(sql_db)
137
+
138
+ new_state = {
139
+ "db_path": path,
140
+ "sql_db": sql_db,
141
+ "schema_text": schema_text,
142
+ "chain": chain,
143
+ }
144
+ return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text
145
+
146
+ def extract_sql_safe(output_text: str) -> str:
147
+ try:
148
+ obj = json.loads(output_text)
149
+ if isinstance(obj, dict) and "sql" in obj:
150
+ return obj["sql"].strip()
151
+ except Exception:
152
+ pass
153
+ m = re.search(r"```sql\s*(.*?)\s*```", output_text, re.DOTALL | re.IGNORECASE)
154
+ if m:
155
+ return m.group(1).strip()
156
+ return output_text.strip()
157
+
158
+ def on_generate_query(question , max_rows, state):
159
+ if not state or not state.get("db_path") or not state.get("chain"):
160
+ return "Please upload a database first.", "", ""
161
+ if not question or not question.strip():
162
+ return "Please enter a question.", "", ""
163
+
164
+ try:
165
+ generated_sql = state["chain"].invoke({"question": question})
166
+
167
+ sql = extract_sql_safe(str(generated_sql))
168
+
169
+ ok, msg = validate_sql_safe(sql)
170
+ if not ok:
171
+ return f"Blocked SQL: {msg}", sql, ""
172
+
173
+ cols, rows = execute_select(state["db_path"], sql, max_rows=max_rows)
174
+ if not cols:
175
+ return f"No rows returned.", sql, "[]"
176
+
177
+ sample = [dict(zip(cols, r)) for r in rows[:50]]
178
+ return f"Returned {len(rows)} row(s). Showing up to 50.", sql, json.dumps(sample, indent=2)
179
+
180
+ except Exception as e:
181
+ return f"Error: {e}", "", ""
182
+
183
+
184
+ with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo:
185
+ gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)")
186
+ gr.Markdown(
187
+ "Upload a **SQLite** file, ask a question in natural language, "
188
+ "and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, "
189
+ "and (4) show you the results."
190
+ )
191
+
192
+ state = gr.State({"db_path": None, "sql_db": None, "schema_text": "", "chain": None})
193
+
194
+ with gr.Row():
195
+ db_file = gr.File(label="Upload SQlite Database", file_types=[".sqlite", ".db"])
196
+ upload_status = gr.Textbox(label="upload Status", interactive=False)
197
+
198
+ schema_box = gr.Accordion("Database schema (preview)", open=False)
199
+ with schema_box:
200
+ schema_md = gr.Markdown("(no schema)")
201
+
202
+ gr.Markdown("---")
203
+
204
+ with gr.Row():
205
+ question = gr.Textbox(label="Your question", placeholder="e.g., Top 10 tracks by total sales")
206
+ with gr.Row():
207
+ max_row= gr.Slider(10, 5000, value=1000, step=10, label="Max rows")
208
+
209
+ with gr.Row():
210
+ run_btn = gr.Button("Generate & Run SQL", variant="primary")
211
+
212
+ with gr.Row():
213
+ status_out = gr.Textbox(label="Status")
214
+ with gr.Row():
215
+ sql_out = gr.Code(label="Generated SQL (validated)")
216
+ with gr.Row():
217
+ result_out = gr.Code(label="Result (JSON sample)")
218
+
219
+ db_file.change(
220
+ fn=on_upload_database,
221
+ inputs=[db_file, state],
222
+ outputs=[state, upload_status, schema_md],
223
+ )
224
+
225
+ run_btn.click(
226
+ fn=on_generate_query,
227
+ inputs=[question, max_row, state],
228
+ outputs=[status_out, sql_out, result_out],
229
+ )
230
+
231
+
232
+
233
+ if __name__ == "__main__":
234
+ demo.launch()
config.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ # ----------------------------
5
+ # Load .env
6
+ # ----------------------------
7
+ load_dotenv()
8
+
9
+
10
+ def get_env_var(name: str, required: bool = True, default: str | None = None) -> str | None:
11
+ """Safely get an environment variable or raise a clear error if missing."""
12
+ value = os.getenv(name, default)
13
+ if required and not value:
14
+ raise ValueError(f"Missing required environment variable: {name}")
15
+ return value
16
+
17
+ # ----------------------------
18
+ # Detect which mode we're in
19
+ # ----------------------------
20
+ PROXY_TOKEN = os.getenv("PROXY_API_KEY")
21
+ PROXY_BASE_URL = os.getenv("PROXY_BASE_URL")
22
+
23
+ if PROXY_TOKEN and PROXY_BASE_URL:
24
+ MODE = "proxy"
25
+ os.environ["OPENAI_API_KEY"] = PROXY_TOKEN
26
+ os.environ["OPENAI_BASE_URL"] = PROXY_BASE_URL
27
+ else:
28
+ MODE = "direct"
29
+ os.environ["OPENAI_API_KEY"] = get_env_var("OPENAI_API_KEY")
30
+ if base_url := os.getenv("OPENAI_BASE_URL"):
31
+ os.environ["OPENAI_BASE_URL"] = base_url
32
+
33
+ # ----------------------------
34
+ # Exported values
35
+ # ----------------------------
36
+ OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
37
+ OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
38
+
39
+ # ----------------------------
40
+ # Optional logging for clarity
41
+ # ----------------------------
42
+ print(f"[config] Mode: {MODE.upper()} | Base URL: {OPENAI_BASE_URL}")
43
+
44
+ LLM_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") # or gpt-4o, gpt-4o-mini
45
+ LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
46
+
47
+ # Hard blocks (defense-in-depth)
48
+ FORBIDDEN_KEYWORDS = {
49
+ "ATTACH", "PRAGMA",
50
+ "CREATE", "DROP", "ALTER", "VACUUM", "REINDEX", "TRIGGER",
51
+ "INSERT", "UPDATE", "DELETE", "REPLACE",
52
+ "GRANT", "REVOKE",
53
+ "BEGIN", "END", "COMMIT", "ROLLBACK",
54
+ "DETACH",
55
+ }
56
+ FORBIDDEN_TABLES = {"sqlite_master", "sqlite_temp_master"}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ langchain
3
+ langchain-openai
4
+ sqlglot
5
+ openai
6
+ python-dotenv