Melika Kheirieh committed on
Commit
b568b83
·
1 Parent(s): a0aff5b

feat(demo): add initial Gradio UI with optional SQLite upload

Browse files
Files changed (4) hide show
  1. app/main.py +2 -1
  2. app/routers/nl2sql.py +141 -34
  3. demo/app.py +108 -0
  4. requirements.txt +3 -1
app/main.py CHANGED
@@ -1,8 +1,9 @@
1
  from dotenv import load_dotenv
 
 
2
  from fastapi import FastAPI
3
  from app.routers import nl2sql
4
 
5
- load_dotenv()
6
 
7
 
8
  app = FastAPI(
 
1
  from dotenv import load_dotenv
2
+ load_dotenv()
3
+
4
  from fastapi import FastAPI
5
  from app.routers import nl2sql
6
 
 
7
 
8
 
9
  app = FastAPI(
app/routers/nl2sql.py CHANGED
@@ -1,5 +1,5 @@
1
  from dataclasses import asdict, is_dataclass
2
- from fastapi import APIRouter, HTTPException
3
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
4
  from nl2sql.pipeline import Pipeline, FinalResult
5
  from nl2sql.ambiguity_detector import AmbiguityDetector
@@ -12,75 +12,182 @@ from nl2sql.verifier import Verifier
12
  from nl2sql.repair import Repair
13
  from adapters.db.sqlite_adapter import SQLiteAdapter
14
  from adapters.db.postgres_adapter import PostgresAdapter
15
- import os
16
- from typing import Union
17
 
 
 
 
 
18
 
19
  router = APIRouter(prefix="/nl2sql")
20
 
21
-
22
- _db: Union[PostgresAdapter, SQLiteAdapter]
23
- if os.getenv("DB_MODE", "sqlite") == "postgres":
24
- _db = PostgresAdapter(os.environ["POSTGRES_DSN"])
25
- else:
26
- _db = SQLiteAdapter("data/chinook.db")
27
-
28
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def get_llm():
30
  return OpenAIProvider()
31
 
32
-
33
- # _db = SQLiteAdapter("data/chinook.db")
34
- _executor = Executor(_db)
 
35
  _verifier = Verifier()
36
  _repair = Repair(get_llm())
37
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- _pipeline = Pipeline(
40
- detector=AmbiguityDetector(),
41
- planner=Planner(get_llm()),
42
- generator=Generator(get_llm()),
43
- safety=Safety(),
44
- executor=_executor,
45
- verifier=_verifier,
46
- repair=_repair,
47
- )
48
-
49
-
50
  def _to_dict(obj):
51
- """Helper: safely convert dataclass → dict."""
52
  return asdict(obj) if is_dataclass(obj) else obj
53
 
54
-
55
  def _round_trace(t: dict) -> dict:
 
56
  if t.get("cost_usd") is not None:
57
  t["cost_usd"] = round(t["cost_usd"], 6)
58
  if t.get("duration_ms") is not None:
59
  t["duration_ms"] = round(t["duration_ms"], 2)
60
  return t
61
 
62
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  @router.post("", name="nl2sql_handler")
64
  def nl2sql_handler(request: NL2SQLRequest):
65
- result = _pipeline.run(
 
 
 
 
 
 
 
 
 
 
 
66
  user_query=request.query,
67
- schema_preview=request.schema_preview,
68
  )
69
 
70
- # --- Ensure result type ---
71
  if not isinstance(result, FinalResult):
72
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
73
 
74
- # --- Handle ambiguity ---
75
  if result.ambiguous and result.questions:
76
  return ClarifyResponse(ambiguous=True, questions=result.questions)
77
 
78
- # --- Handle error ---
79
  if not result.ok or result.error:
80
  detail = "; ".join(result.details or ["Unknown error"])
81
  raise HTTPException(status_code=400, detail=detail)
82
 
83
- # --- Success case ---
84
  traces = [_round_trace(t) for t in (result.traces or [])]
85
  return NL2SQLResponse(
86
  ambiguous=False,
 
1
  from dataclasses import asdict, is_dataclass
2
+ from fastapi import APIRouter, HTTPException, UploadFile, File
3
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
4
  from nl2sql.pipeline import Pipeline, FinalResult
5
  from nl2sql.ambiguity_detector import AmbiguityDetector
 
12
  from nl2sql.repair import Repair
13
  from adapters.db.sqlite_adapter import SQLiteAdapter
14
  from adapters.db.postgres_adapter import PostgresAdapter
 
 
15
 
16
+ import os
17
+ import time
18
+ import uuid
19
+ from typing import Union, Optional, Dict
20
 
21
  router = APIRouter(prefix="/nl2sql")
22
 
23
# -------------------------------
# Runtime DB registry (for uploaded SQLite files)
# Files are stored under /tmp, mapped by a short-lived db_id
# -------------------------------
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
_DB_TTL_SECONDS = int(os.getenv("DB_TTL_SECONDS", "7200"))  # default 2 hours
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)

# In-memory map: db_id -> {"path": str, "ts": float}
# NOTE(review): this map is process-local — a db_id issued by one worker will
# not resolve in another; confirm single-process deployment (e.g. one uvicorn
# worker) or move the registry to shared storage.
_DB_MAP: Dict[str, Dict[str, object]] = {}

# -------------------------------
# Default DB resolution
# -------------------------------
DB_MODE = os.getenv("DB_MODE", "sqlite").lower()  # "sqlite" or "postgres"
POSTGRES_DSN = os.getenv("POSTGRES_DSN")
DEFAULT_SQLITE_DB = os.getenv("DEFAULT_SQLITE_DB", "data/chinook.db")  # keep your current default
40
+
41
def _cleanup_db_map() -> None:
    """Best-effort sweep of the upload registry.

    Drops every entry whose timestamp is older than the TTL and tries to
    delete its backing file; deletion failures are ignored so a stale or
    locked file can never break request handling.
    """
    cutoff = time.time() - _DB_TTL_SECONDS
    stale_ids = [key for key, meta in _DB_MAP.items() if float(meta.get("ts", 0)) < cutoff]
    for key in stale_ids:
        file_path = _DB_MAP[key].get("path")
        try:
            if isinstance(file_path, str) and os.path.exists(file_path):
                os.remove(file_path)
        except Exception:
            # best-effort: never propagate cleanup errors to the caller
            pass
        _DB_MAP.pop(key, None)
53
+
54
def _resolve_sqlite_path(db_id: Optional[str]) -> str:
    """Return the SQLite file path for *db_id*, or the default DB path.

    Expired uploads are swept first, so a stale db_id falls back to the
    default database instead of pointing at a deleted file.
    """
    _cleanup_db_map()
    entry = _DB_MAP.get(db_id) if db_id else None
    if entry is not None:
        return str(entry["path"])
    return DEFAULT_SQLITE_DB
60
+
61
def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
    """Build the DB adapter for one request.

    - postgres mode: always PostgresAdapter(POSTGRES_DSN); 500 if the DSN
      is not configured.
    - sqlite mode: the uploaded SQLite file for *db_id* when known,
      otherwise DEFAULT_SQLITE_DB.
    """
    if DB_MODE != "postgres":
        # sqlite mode.
        # NOTE: SQLiteAdapter should open DB in read-only mode internally if supported.
        # If not, ensure your adapter enforces PRAGMA query_only=ON and prevents DDL/DML.
        return SQLiteAdapter(_resolve_sqlite_path(db_id))

    if not POSTGRES_DSN:
        raise HTTPException(status_code=500, detail="POSTGRES_DSN is not configured")
    return PostgresAdapter(POSTGRES_DSN)
77
+
78
+ # -------------------------------
79
+ # LLM providers & shared components (stateless)
80
+ # -------------------------------
81
def get_llm():
    """Return a new OpenAIProvider instance (constructed on every call)."""
    return OpenAIProvider()
83
 
84
+ _detector = AmbiguityDetector()
85
+ _planner = Planner(get_llm())
86
+ _generator = Generator(get_llm())
87
+ _safety = Safety()
88
  _verifier = Verifier()
89
  _repair = Repair(get_llm())
90
 
91
def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
    """Assemble a Pipeline around a per-request Executor.

    All other stages are stateless module-level singletons; only the
    Executor is rebuilt so it binds to the adapter chosen for this call.
    """
    return Pipeline(
        detector=_detector,
        planner=_planner,
        generator=_generator,
        safety=_safety,
        executor=Executor(adapter),
        verifier=_verifier,
        repair=_repair,
    )
103
 
104
+ # -------------------------------
105
+ # Helpers
106
+ # -------------------------------
 
 
 
 
 
 
 
 
107
  def _to_dict(obj):
108
+ """Safely convert dataclass → dict."""
109
  return asdict(obj) if is_dataclass(obj) else obj
110
 
 
111
  def _round_trace(t: dict) -> dict:
112
+ """Round float fields to keep responses tidy and stable."""
113
  if t.get("cost_usd") is not None:
114
  t["cost_usd"] = round(t["cost_usd"], 6)
115
  if t.get("duration_ms") is not None:
116
  t["duration_ms"] = round(t["duration_ms"], 2)
117
  return t
118
 
119
+ # -------------------------------
120
+ # Upload endpoint (SQLite only)
121
+ # Path will be /api/nl2sql/upload_db if your root APIRouter is mounted at /api
122
+ # -------------------------------
123
@router.post("/upload_db")
async def upload_db(file: UploadFile = File(...)):
    """
    Upload a SQLite database (.db/.sqlite). Returns a short-lived db_id.
    Notes:
    - Only SQLite files are allowed here (not for Postgres mode).
    - Max size ~20MB recommended for demo environments like HF Spaces.
    - Files are stored under /tmp and cleaned by TTL.
    """
    if DB_MODE != "sqlite":
        raise HTTPException(status_code=400, detail="DB upload is only supported in sqlite mode")

    filename = file.filename or "db.sqlite"
    if not filename.endswith((".db", ".sqlite")):
        raise HTTPException(status_code=400, detail="Only .db or .sqlite files are allowed")

    payload = await file.read()
    max_bytes = int(os.getenv("UPLOAD_MAX_BYTES", str(20 * 1024 * 1024)))  # 20 MB
    if len(payload) > max_bytes:
        raise HTTPException(status_code=400, detail=f"File too large (> {max_bytes} bytes)")

    # Persist under a fresh UUID so uploads can never collide or overwrite.
    db_id = str(uuid.uuid4())
    dest_path = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
    try:
        with open(dest_path, "wb") as fh:
            fh.write(payload)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}")

    _DB_MAP[db_id] = {"path": dest_path, "ts": time.time()}
    return {"db_id": db_id}
154
+
155
+ # -------------------------------
156
+ # Main NL2SQL endpoint
157
+ # Path will be /api/nl2sql if your root APIRouter is mounted at /api
158
+ # -------------------------------
159
  @router.post("", name="nl2sql_handler")
160
  def nl2sql_handler(request: NL2SQLRequest):
161
+ """
162
+ Handle NL → SQL pipeline execution.
163
+ Optional: if the incoming request model supports `db_id`, we switch DB for this call.
164
+ Otherwise we will silently ignore and use default DB (or Postgres, based on mode).
165
+ """
166
+ # Try to extract db_id if present in request (without breaking strict models)
167
+ db_id = getattr(request, "db_id", None) # Optional[str]
168
+ # Build per-request pipeline bound to the selected adapter
169
+ adapter = _select_adapter(db_id)
170
+ pipeline = _build_pipeline(adapter)
171
+
172
+ result = pipeline.run(
173
  user_query=request.query,
174
+ schema_preview=getattr(request, "schema_preview", None),
175
  )
176
 
177
+ # Ensure result type
178
  if not isinstance(result, FinalResult):
179
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
180
 
181
+ # Ambiguity: return clarify payload
182
  if result.ambiguous and result.questions:
183
  return ClarifyResponse(ambiguous=True, questions=result.questions)
184
 
185
+ # Error: bubble up details
186
  if not result.ok or result.error:
187
  detail = "; ".join(result.details or ["Unknown error"])
188
  raise HTTPException(status_code=400, detail=detail)
189
 
190
+ # Success
191
  traces = [_round_trace(t) for t in (result.traces or [])]
192
  return NL2SQLResponse(
193
  ambiguous=False,
demo/app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import requests
3
+ import gradio as gr
4
+
5
# Local-dev endpoints for the FastAPI backend.
# NOTE(review): server-side comments describe the routes mounting at /api
# (not /api/v1) — confirm the actual root router prefix, otherwise these
# requests will 404.
API_UPLOAD = "http://localhost:8000/api/v1/nl2sql/upload_db"
API_QUERY = "http://localhost:8000/api/v1/nl2sql"
7
+
8
+
9
def upload_db(file_obj):
    """POST a local SQLite file to the API; return (db_id, status message).

    Returns (None, message) without calling the API when no file was given,
    the extension is wrong, or the file exceeds the demo size limit.
    """
    if file_obj is None:
        return None, "No DB uploaded. Default DB will be used."

    name = getattr(file_obj, "name", "db.sqlite")
    if not name.endswith((".db", ".sqlite")):
        return None, "Only .db or .sqlite files are allowed."

    size = getattr(file_obj, "size", None)
    if size and size > 20 * 1024 * 1024:
        return None, "File too large (>20MB). Use a smaller demo DB."

    # Read bytes from the temp path gradio hands us.
    with open(file_obj.name, "rb") as fh:
        data = fh.read()

    response = requests.post(
        API_UPLOAD,
        files={"file": (name, io.BytesIO(data), "application/octet-stream")},
        timeout=60,
    )
    response.raise_for_status()
    db_id = response.json().get("db_id")
    return db_id, f"Uploaded OK. db_id={db_id}"
31
+
32
+
33
def query_to_sql(user_query, db_id, debug):
    """Call the NL2SQL API and unpack the JSON response into the tuple of
    values the Gradio widgets expect (badges, sql, explanation, result,
    trace, repair candidates, repair diff, timings table)."""
    payload = {"query": user_query, "debug": bool(debug)}
    if db_id:
        payload["db_id"] = db_id
    resp = requests.post(API_QUERY, json=payload, timeout=120)
    resp.raise_for_status()
    d = resp.json()

    sql = d.get("sql_final") or d.get("sql") or ""
    explanation = d.get("explanation", "")
    result = d.get("result", [])

    # Flags summary
    ambiguous = "Yes" if d.get("ambiguous") else "No"
    safety_info = d.get("safety", {})
    if safety_info.get("allowed", True):
        safety = "Allowed"
    else:
        safety = f"Blocked: {safety_info.get('blocked_reason')}"
    verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
    repair = d.get("repair", {})
    repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"

    # Per-stage timings as sorted [stage, ms] rows for the Dataframe widget.
    timings = d.get("timings_ms", {})
    timings_table = [[stage, timings[stage]] for stage in sorted(timings)]

    return (
        f"Ambiguous: {ambiguous} | Safety: {safety} | Verification: {verification} | Repair: {repair_text}",
        sql,
        explanation,
        result,
        d.get("trace", []),
        repair.get("candidates", []),
        repair.get("diff", ""),
        timings_table,
    )
65
+
66
+
67
# --- Gradio UI wiring (module-level script) ---
with gr.Blocks(title="NL2SQL Copilot") as demo:
    gr.Markdown("# NL2SQL Copilot\nUpload a SQLite DB (optional) or use default.")

    # Holds the db_id returned by the upload endpoint; None means default DB.
    db_state = gr.State(value=None)

    with gr.Row():
        db_file = gr.File(label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"])
        upload_btn = gr.Button("Upload DB")
        db_msg = gr.Markdown()
        upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])

    with gr.Row():
        q = gr.Textbox(label="Question", scale=4)
        debug = gr.Checkbox(label="Debug", value=True, scale=1)
        run = gr.Button("Run")

    # Status line plus the primary SQL/explanation outputs.
    badges = gr.Markdown()
    sql_out = gr.Code(label="Final SQL", language="sql")
    exp_out = gr.Textbox(label="Explanation", lines=3)

    with gr.Tab("Result"):
        res_out = gr.JSON()

    with gr.Tab("Trace"):
        trace = gr.JSON(label="Stage trace")

    with gr.Tab("Repair"):
        repair_candidates = gr.JSON(label="Candidates")
        repair_diff = gr.Code(label="SQL Diff", language="sql")

    with gr.Tab("Timings"):
        timings = gr.Dataframe(headers=["stage", "ms"], datatype=["str", "number"])

    # The outputs list must stay in the same order as query_to_sql's return tuple.
    run.click(
        query_to_sql,
        inputs=[q, db_state, debug],
        outputs=[badges, sql_out, exp_out, res_out, trace, repair_candidates, repair_diff, timings],
    )

if __name__ == "__main__":
    # Let Gradio pick a free port by default to avoid collisions
    demo.launch()
requirements.txt CHANGED
@@ -9,4 +9,6 @@ pytest==8.3.3
9
  python-dotenv==1.1.1
10
  openai==2.6.1
11
  psycopg[binary]~=3.2
12
- ruff
 
 
 
9
  python-dotenv==1.1.1
10
  openai==2.6.1
11
  psycopg[binary]~=3.2
12
+ ruff
13
+ gradio
14
+ sqlalchemy