Spaces:
Sleeping
Sleeping
Melika Kheirieh
committed on
Commit
·
a82f275
0
Parent(s):
Initial commit: basic Gradio + Langchain SQL copilot prototype
Browse files- .env.example +11 -0
- .gitignore +4 -0
- .idea/.gitignore +8 -0
- .idea/dataSources.xml +12 -0
- .idea/inspectionProfiles/Project_Default.xml +23 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/nl2sql-copilot-prototype.iml +8 -0
- app.py +234 -0
- config.py +56 -0
- requirements.txt +6 -0
.env.example
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---- GAPGPT proxy config ----
|
| 2 |
+
# If you’re using a proxy (e.g., GapGPT, Helicone, LocalAI, etc.),
|
| 3 |
+
# set these two values. Otherwise, leave them blank.
|
| 4 |
+
PROXY_API_KEY="your-proxy-token-here"
|
| 5 |
+
PROXY_BASE_URL="https://api.proxy.app/v1"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# ---- optional direct OpenAI config (for fallback) ----
|
| 9 |
+
# These will be used only if proxy variables are not set.
|
| 10 |
+
#OPENAI_API_KEY="your-openai-key-here"
|
| 11 |
+
#OPENAI_BASE_URL="https://api.openai.com/v1"
|
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
__pycache__/
|
| 3 |
+
.venv/
|
| 4 |
+
.DS_Store
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/dataSources.xml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
| 4 |
+
<data-source source="LOCAL" name="Chinook_Sqlite" uuid="4036a8cf-a7c0-4e84-909d-ada6895430c6">
|
| 5 |
+
<driver-ref>sqlite.xerial</driver-ref>
|
| 6 |
+
<synchronize>true</synchronize>
|
| 7 |
+
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
| 8 |
+
<jdbc-url>jdbc:sqlite:$PROJECT_DIR$/db/Chinook_Sqlite.sqlite</jdbc-url>
|
| 9 |
+
<working-dir>$ProjectFileDir$</working-dir>
|
| 10 |
+
</data-source>
|
| 11 |
+
</component>
|
| 12 |
+
</project>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
| 5 |
+
<option name="ignoredPackages">
|
| 6 |
+
<value>
|
| 7 |
+
<list size="10">
|
| 8 |
+
<item index="0" class="java.lang.String" itemvalue="tiktoken" />
|
| 9 |
+
<item index="1" class="java.lang.String" itemvalue="openai" />
|
| 10 |
+
<item index="2" class="java.lang.String" itemvalue="langchain-community" />
|
| 11 |
+
<item index="3" class="java.lang.String" itemvalue="langgraph" />
|
| 12 |
+
<item index="4" class="java.lang.String" itemvalue="pydantic" />
|
| 13 |
+
<item index="5" class="java.lang.String" itemvalue="regex" />
|
| 14 |
+
<item index="6" class="java.lang.String" itemvalue="langchain-openai" />
|
| 15 |
+
<item index="7" class="java.lang.String" itemvalue="langchain" />
|
| 16 |
+
<item index="8" class="java.lang.String" itemvalue="lxml" />
|
| 17 |
+
<item index="9" class="java.lang.String" itemvalue="html5lib" />
|
| 18 |
+
</list>
|
| 19 |
+
</value>
|
| 20 |
+
</option>
|
| 21 |
+
</inspection_tool>
|
| 22 |
+
</profile>
|
| 23 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="Python 3.13 (ldoce5viewer-master)" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="LLM" project-jdk-type="Python SDK" />
|
| 7 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/Text-to-SQL.iml" filepath="$PROJECT_DIR$/.idea/Text-to-SQL.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/nl2sql-copilot-prototype.iml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="jdk" jdkName="LLM" jdkType="Python SDK" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
</module>
|
app.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from config import (
|
| 2 |
+
LLM_MODEL,
|
| 3 |
+
LLM_TEMPERATURE,
|
| 4 |
+
FORBIDDEN_KEYWORDS,
|
| 5 |
+
FORBIDDEN_TABLES
|
| 6 |
+
)
|
| 7 |
+
import os
|
| 8 |
+
import sqlite3
|
| 9 |
+
import json
|
| 10 |
+
import re
|
| 11 |
+
from typing import Optional, Tuple, List
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import sqlglot
|
| 15 |
+
from sqlglot import exp
|
| 16 |
+
|
| 17 |
+
from langchain_openai import ChatOpenAI
|
| 18 |
+
from langchain_community.utilities import SQLDatabase
|
| 19 |
+
from langchain.chains import create_sql_query_chain
|
| 20 |
+
from langchain.prompts import ChatPromptTemplate
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_readonly_sqlite_url(db_path: str) -> str:
    """Build a SQLite URI that opens *db_path* in read-only mode.

    The returned string is meant for ``sqlite3.connect(..., uri=True)``.
    Note: ``uri=True`` is a keyword argument of ``sqlite3.connect``, not a
    URI query parameter — the original embedded ``&uri=true`` in the URI,
    where SQLite treats it as an unknown (ignored) parameter. Dropped here.
    """
    # mode=ro makes every connection opened through this URI read-only,
    # which is the core safety guarantee for executing LLM-generated SQL.
    return f"file:{db_path}?mode=ro"
|
| 25 |
+
|
| 26 |
+
def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str:
    """Return a human-readable summary of the user tables in a SQLite file.

    Each table becomes one line of the form ``- name (col:TYPE, ...)``; when
    ``limit_per_table`` > 0, up to that many sample rows are fetched per
    table and their count reported on an indented follow-up line.

    The database is opened read-only, and tables listed in FORBIDDEN_TABLES
    are skipped.
    """
    uri = get_readonly_sqlite_url(db_path)
    conn = sqlite3.connect(uri, uri=True, timeout=3)
    try:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
        tables = [r["name"] for r in cur.fetchall()]
        lines = []
        for t in tables:
            # Skip SQLite internals / blocked tables (defense-in-depth).
            if t in FORBIDDEN_TABLES:
                continue
            # Table names come from an *uploaded* database, so quote them —
            # identifiers cannot be bound as SQL parameters.
            quoted = '"' + t.replace('"', '""') + '"'
            cur.execute(f"PRAGMA table_info({quoted});")
            cols = cur.fetchall()
            col_line = ", ".join(f"{c['name']}:{c['type']}" for c in cols)
            lines.append(f"- {t} ({col_line})")
            if limit_per_table > 0:
                try:
                    cur.execute(f"SELECT * FROM {quoted} LIMIT {int(limit_per_table)};")
                    sample = cur.fetchall()
                    if sample:
                        lines.append(f"  sample rows: {len(sample)}")
                except Exception:
                    # Sampling is best-effort; the schema line is already emitted.
                    pass
        if not lines:
            return "(no user tables found)"
        return "\n".join(lines)
    finally:
        # 'with sqlite3.connect(...)' only manages transactions, never closes;
        # close explicitly so repeated uploads don't leak file handles.
        conn.close()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def validate_sql_safe(sql: str) -> Tuple[bool, str]:
    """Check that *sql* is a single, read-only SELECT statement.

    Returns ``(True, "OK")`` when every gate passes, else ``(False, reason)``.
    Gates, in order:
      1. no multiple statements (at most one trailing ';'),
      2. none of FORBIDDEN_KEYWORDS appears as a whole word,
      3. sqlglot parses exactly one SELECT statement,
      4. no referenced table is in FORBIDDEN_TABLES.
    """
    # Reject any interior ';' — only a single trailing one is tolerated.
    body = sql.strip()
    if body.endswith(";"):
        body = body[:-1]
    if ";" in body:
        return False, "Multiple statements are not allowed."

    # BUG FIX: the original bound re.sub(...) to a variable named `upper`
    # without calling .upper(), so the case-sensitive match against the
    # upper-case keyword list let e.g. "drop table" slip past this filter.
    normalized = re.sub(r"\s+", " ", sql).strip().upper()
    for kw in FORBIDDEN_KEYWORDS:
        if re.search(rf"\b{kw}\b", normalized):
            return False, f"Keyword '{kw}' is not allowed."

    try:
        parsed = sqlglot.parse(sql, read="sqlite")
    except Exception as e:
        return False, f"SQL parse error: {e}"

    if not parsed or len(parsed) != 1:
        return False, "Exactly one SQL statement is allowed."

    stmt = parsed[0]
    if not isinstance(stmt, exp.Select):
        return False, "Only SELECT statements are allowed."

    # Walk every table reference in the AST, including joins/subqueries.
    for table in stmt.find_all(exp.Table):
        table_name = table.name.lower() if table.name else ""
        if table_name in FORBIDDEN_TABLES:
            return False, f"Access to {table_name} is not allowed."

    return True, "OK"
|
| 86 |
+
|
| 87 |
+
def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[List[str], List[List]]:
    """Execute a (pre-validated) SELECT read-only; return (column_names, rows).

    A ``LIMIT max_rows`` clause is appended when the query has none, capping
    the result size. Returns ``([], [])`` when no rows come back.
    """
    uri = get_readonly_sqlite_url(db_path)
    # Cap the result set unless the query already limits itself.
    # int() guards against a float slider value producing "LIMIT 1000.0".
    if not re.search(r"\bLIMIT\b", sql, re.IGNORECASE):
        sql = f"{sql.rstrip(';')} LIMIT {int(max_rows)}"

    conn = sqlite3.connect(uri, uri=True, timeout=timeout)
    try:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        if not rows:
            return [], []
        # sqlite3.Row exposes column names; convert rows to plain lists.
        return list(rows[0].keys()), [list(r) for r in rows]
    finally:
        # sqlite3's context manager commits/rolls back but never closes;
        # close explicitly so repeated queries don't leak file handles.
        conn.close()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
custom_prompt = ChatPromptTemplate.from_template("""
|
| 107 |
+
Given the following question, return ONLY a valid SQL query in JSON form.
|
| 108 |
+
|
| 109 |
+
Question: {input}
|
| 110 |
+
Database schema: {table_info}
|
| 111 |
+
|
| 112 |
+
You may sample/preview at most {top_k} rows if you need examples.
|
| 113 |
+
|
| 114 |
+
Respond in this exact JSON format:
|
| 115 |
+
{{
|
| 116 |
+
"sql": "<SQL_QUERY_HERE>"
|
| 117 |
+
}}
|
| 118 |
+
""")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def make_sql_chain(sql_db: SQLDatabase):
    """Build the LangChain text-to-SQL chain bound to *sql_db*.

    Uses the module-level ``custom_prompt`` plus the configured model and
    temperature; ``k=20`` caps the rows the chain may preview.
    """
    model = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
    return create_sql_query_chain(model, sql_db, prompt=custom_prompt, k=20)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def on_upload_database(db_file, state):
    """Gradio handler for a database upload.

    Returns ``(new_state, status_message, schema_markdown)`` so the UI can
    refresh the shared state, the status box, and the schema preview at once.
    """
    if db_file is None:
        return state, "No file provided.", "(no schema)"
    path = db_file.name

    try:
        # Open through SQLAlchemy for LangChain; the schema preview is read
        # separately over the read-only sqlite3 URI.
        sql_db = SQLDatabase.from_uri(f"sqlite:///{path}")
        schema_text = get_schema_preview(path, limit_per_table=0)
        chain = make_sql_chain(sql_db)
    except Exception as e:
        # A corrupt or non-SQLite upload should surface as a status message,
        # not crash the Gradio callback (the original propagated the error).
        return state, f"Failed to load database: {e}", "(no schema)"

    new_state = {
        "db_path": path,
        "sql_db": sql_db,
        "schema_text": schema_text,
        "chain": chain,
    }
    return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text
|
| 145 |
+
|
| 146 |
+
def extract_sql_safe(output_text: str) -> str:
    """Pull the SQL string out of an LLM response.

    Tries, in order: a JSON object with a string-valued "sql" key, a
    ```sql fenced code block, and finally the raw (stripped) text.
    """
    try:
        obj = json.loads(output_text)
        # Guard the value's type too: a non-string "sql" payload (e.g. a
        # number) previously crashed on .strip().
        if isinstance(obj, dict) and isinstance(obj.get("sql"), str):
            return obj["sql"].strip()
    except Exception:
        # Not JSON — fall through to the fenced-block heuristic.
        pass
    m = re.search(r"```sql\s*(.*?)\s*```", output_text, re.DOTALL | re.IGNORECASE)
    if m:
        return m.group(1).strip()
    return output_text.strip()
|
| 157 |
+
|
| 158 |
+
def on_generate_query(question , max_rows, state):
    """Gradio handler: NL question -> SQL -> validation -> execution.

    Returns ``(status_message, sql_text, result_json)`` for the three output
    widgets; on any failure the SQL/result slots degrade gracefully.
    """
    if not state or not state.get("db_path") or not state.get("chain"):
        return "Please upload a database first.", "", ""
    if not question or not question.strip():
        return "Please enter a question.", "", ""

    try:
        generated_sql = state["chain"].invoke({"question": question})

        # The chain's output shape varies; normalize to str before parsing.
        sql = extract_sql_safe(str(generated_sql))

        ok, msg = validate_sql_safe(sql)
        if not ok:
            return f"Blocked SQL: {msg}", sql, ""

        # Gradio sliders deliver floats; execute_select interpolates the
        # value into a LIMIT clause, so coerce to int here.
        cols, rows = execute_select(state["db_path"], sql, max_rows=int(max_rows))
        if not cols:
            return "No rows returned.", sql, "[]"

        # Show at most 50 rows as JSON records for readability.
        sample = [dict(zip(cols, r)) for r in rows[:50]]
        return f"Returned {len(rows)} row(s). Showing up to 50.", sql, json.dumps(sample, indent=2)

    except Exception as e:
        # Last-resort guard so the UI always receives a readable status.
        return f"Error: {e}", "", ""
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# ---------------- Gradio UI ----------------
with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo:
    gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)")
    gr.Markdown(
        "Upload a **SQLite** file, ask a question in natural language, "
        "and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, "
        "and (4) show you the results."
    )

    # Per-session state shared between the upload and query callbacks.
    state = gr.State({"db_path": None, "sql_db": None, "schema_text": "", "chain": None})

    with gr.Row():
        db_file = gr.File(label="Upload SQlite Database", file_types=[".sqlite", ".db"])
        upload_status = gr.Textbox(label="upload Status", interactive=False)

    # Collapsible schema preview, populated after an upload.
    schema_box = gr.Accordion("Database schema (preview)", open=False)
    with schema_box:
        schema_md = gr.Markdown("(no schema)")

    gr.Markdown("---")

    with gr.Row():
        question = gr.Textbox(label="Your question", placeholder="e.g., Top 10 tracks by total sales")
    with gr.Row():
        # Upper bound for the LIMIT that execute_select applies.
        max_row= gr.Slider(10, 5000, value=1000, step=10, label="Max rows")

    with gr.Row():
        run_btn = gr.Button("Generate & Run SQL", variant="primary")

    with gr.Row():
        status_out = gr.Textbox(label="Status")
    with gr.Row():
        sql_out = gr.Code(label="Generated SQL (validated)")
    with gr.Row():
        result_out = gr.Code(label="Result (JSON sample)")

    # Event wiring: a new file refreshes state + schema preview; the button
    # runs the full generate -> validate -> execute pipeline.
    db_file.change(
        fn=on_upload_database,
        inputs=[db_file, state],
        outputs=[state, upload_status, schema_md],
    )

    run_btn.click(
        fn=on_generate_query,
        inputs=[question, max_row, state],
        outputs=[status_out, sql_out, result_out],
    )



if __name__ == "__main__":
    demo.launch()
|
config.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
# ----------------------------
|
| 5 |
+
# Load .env
|
| 6 |
+
# ----------------------------
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_env_var(name: str, required: bool = True, default: str | None = None) -> str | None:
|
| 11 |
+
"""Safely get an environment variable or raise a clear error if missing."""
|
| 12 |
+
value = os.getenv(name, default)
|
| 13 |
+
if required and not value:
|
| 14 |
+
raise ValueError(f"Missing required environment variable: {name}")
|
| 15 |
+
return value
|
| 16 |
+
|
| 17 |
+
# ----------------------------
# Detect which mode we're in
# ----------------------------
# "proxy" mode routes OpenAI traffic through a gateway by re-exporting the
# proxy credentials under the standard OPENAI_* variable names, which the
# OpenAI/LangChain clients read implicitly.
PROXY_TOKEN = os.getenv("PROXY_API_KEY")
PROXY_BASE_URL = os.getenv("PROXY_BASE_URL")

if PROXY_TOKEN and PROXY_BASE_URL:
    MODE = "proxy"
    os.environ["OPENAI_API_KEY"] = PROXY_TOKEN
    os.environ["OPENAI_BASE_URL"] = PROXY_BASE_URL
else:
    MODE = "direct"
    # Raises ValueError at import time when neither proxy nor direct
    # credentials are configured — fail fast instead of at first request.
    os.environ["OPENAI_API_KEY"] = get_env_var("OPENAI_API_KEY")
    if base_url := os.getenv("OPENAI_BASE_URL"):
        os.environ["OPENAI_BASE_URL"] = base_url

# ----------------------------
# Exported values
# ----------------------------
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")

# ----------------------------
# Optional logging for clarity
# ----------------------------
print(f"[config] Mode: {MODE.upper()} | Base URL: {OPENAI_BASE_URL}")

LLM_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")  # or gpt-4o, gpt-4o-mini
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))

# Hard blocks (defense-in-depth)
# Upper-case SQL keywords that must never appear in generated queries;
# matched as whole words by validate_sql_safe() in app.py.
FORBIDDEN_KEYWORDS = {
    "ATTACH", "PRAGMA",
    "CREATE", "DROP", "ALTER", "VACUUM", "REINDEX", "TRIGGER",
    "INSERT", "UPDATE", "DELETE", "REPLACE",
    "GRANT", "REVOKE",
    "BEGIN", "END", "COMMIT", "ROLLBACK",
    "DETACH",
}
# SQLite catalog tables (lower-case names) that queries may not touch.
FORBIDDEN_TABLES = {"sqlite_master", "sqlite_temp_master"}
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gradio
langchain
langchain-community
langchain-openai
sqlglot
openai
python-dotenv
|