# Melika Kheirieh
# demo(ui): align POST payload with API schema; robust 422 handling; map (sql,rationale,traces) to UI
# c4c85f7
# raw
# history blame
# 5.65 kB
import requests
import gradio as gr
# Backend endpoints used by this demo; both are hard-coded to localhost:8000.
# Upload endpoint: receives the SQLite file as a multipart form upload.
API_UPLOAD = "http://localhost:8000/api/v1/nl2sql/upload_db"
# Query endpoint: receives the natural-language question as a JSON body.
API_QUERY = "http://localhost:8000/api/v1/nl2sql"
def upload_db(file_obj):
if file_obj is None:
return None, "No DB uploaded. Default DB will be used."
name = getattr(file_obj, "name", "db.sqlite")
if not (name.endswith(".db") or name.endswith(".sqlite")):
return None, "Only .db or .sqlite files are allowed."
size = getattr(file_obj, "size", None)
if size and size > 20 * 1024 * 1024:
return None, "File too large (>20MB) for this demo."
# Gradio gives a temp file path as file_obj.name
files = {"file": (name, open(file_obj.name, "rb"), "application/octet-stream")}
try:
r = requests.post(API_UPLOAD, files=files, timeout=120)
finally:
# best-effort close
try:
files["file"][1].close()
except Exception:
pass
if r.ok:
data = r.json()
db_id = data.get("db_id")
if not db_id:
return None, f"Upload returned no db_id: {data}"
return db_id, f"Uploaded OK. db_id={db_id}"
# Show backend error body
try:
body = r.json()
except ValueError:
body = r.text
return None, f"Upload failed ({r.status_code}): {body}"
def _post_query(payload: dict) -> tuple[bool, object]:
    """POST *payload* as JSON to the query endpoint and normalise the outcome.

    Args:
        payload: JSON-serialisable request body.

    Returns:
        ``(True, decoded_json)`` for a successful response with a JSON body.
        ``(False, error_string)`` otherwise. For HTTP error responses the
        string has the form ``"<status_code> <body>"``, so callers can detect
        a specific status by prefix. Network-level failures (connection
        refused, timeout, ...) are also returned as ``(False, ...)`` rather
        than raised.
    """
    try:
        r = requests.post(API_QUERY, json=payload, timeout=120)
    except requests.RequestException as exc:
        # Keep the helper's (ok, data) contract even when no response exists.
        return False, f"Request failed: {exc}"

    if r.ok:
        try:
            return True, r.json()
        except ValueError:
            return False, "Backend returned non-JSON body."

    # Error response: prefer the structured JSON body, fall back to raw text.
    try:
        body = r.json()
    except ValueError:
        body = r.text
    return False, f"{r.status_code} {body}"
def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool):
# Build minimal schema-compliant payload.
# Server expects request.query (name is 'query' per router code).
base_payload = {"query": user_query.strip() if user_query else ""}
# First try WITH db_id (if present). If backend rejects (422), retry WITHOUT.
if db_id:
ok, data = _post_query({**base_payload, "db_id": db_id})
if not ok and isinstance(data, str) and data.startswith("422"):
# Retry without db_id in case request model forbids extra fields.
ok, data = _post_query(base_payload)
else:
ok, data = _post_query(base_payload)
if not ok:
# Surface backend error text to the UI
err_badges = f"Error: {data}"
return (
err_badges, # badges
"", # sql_out
"", # exp_out
{}, # result (tab)
[], # trace (tab)
[], # repair_candidates (tab)
"", # repair_diff (tab)
[], # timings (tab)
)
d = data
# Map fields to UI (server returns: ambiguous, sql, rationale, traces)
sql = d.get("sql") or d.get("sql_final") or ""
explanation = d.get("rationale") or d.get("explanation") or ""
result = d.get("result", {}) # optional/maybe absent
trace_list = d.get("traces") or d.get("trace") or []
ambiguous_flag = "Yes" if d.get("ambiguous") else "No"
safety = (
"Allowed"
if d.get("safety", {}).get("allowed", True)
else f"Blocked: {d.get('safety', {}).get('blocked_reason')}"
)
verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
repair = d.get("repair", {}) or {}
repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"
timings = d.get("timings_ms", {}) or {}
timings_table = [[k, timings[k]] for k in sorted(timings.keys())]
badges_text = f"Ambiguous: {ambiguous_flag} | Safety: {safety} | Verification: {verification} | Repair: {repair_text}"
return (
badges_text,
sql,
explanation,
result,
trace_list,
repair.get("candidates", []),
repair.get("diff", ""),
timings_table,
)
# UI definition. Components are created in display order inside the Blocks
# context; the two click handlers connect them to upload_db and query_to_sql.
with gr.Blocks(title="NL2SQL Copilot") as demo:
    gr.Markdown("# NL2SQL Copilot\nUpload a SQLite DB (optional) or use default.")

    # Stores the first value returned by upload_db (the db_id, or None).
    db_state = gr.State(value=None)

    # Upload row: file picker next to its button.
    with gr.Row():
        db_file = gr.File(
            label="Upload SQLite (.db/.sqlite)",
            file_types=[".db", ".sqlite"],
        )
        upload_btn = gr.Button("Upload DB")
    # Status line fed by the second value returned by upload_db.
    db_msg = gr.Markdown()
    upload_btn.click(
        fn=upload_db,
        inputs=[db_file],
        outputs=[db_state, db_msg],
    )

    # Question row: text input, debug toggle and the run button.
    with gr.Row():
        q = gr.Textbox(label="Question", scale=4)
        # keep the checkbox in UI if you like, but we don't send it to backend
        debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
        run = gr.Button("Run")

    # Primary outputs shown above the tabs.
    badges = gr.Markdown()
    sql_out = gr.Code(label="Final SQL", language="sql")
    exp_out = gr.Textbox(label="Explanation", lines=3)

    # Secondary outputs, one tab per category.
    with gr.Tab("Result"):
        res_out = gr.JSON()
    with gr.Tab("Trace"):
        trace = gr.JSON(label="Stage trace")
    with gr.Tab("Repair"):
        repair_candidates = gr.JSON(label="Candidates")
        repair_diff = gr.Textbox(label="Diff (if any)", lines=10)
    with gr.Tab("Timings"):
        timings = gr.Dataframe(
            headers=["metric", "ms"],
            datatype=["str", "number"],
        )

    # The outputs list follows the order of the tuple returned by query_to_sql.
    run.click(
        fn=query_to_sql,
        inputs=[q, db_state, debug],
        outputs=[
            badges,
            sql_out,
            exp_out,
            res_out,
            trace,
            repair_candidates,
            repair_diff,
            timings,
        ],
    )
# Start the demo server only when this file is executed as a script.
if __name__ == "__main__":
    # Let Gradio pick a free port by default to avoid collisions
    # (no server_port is passed to launch()).
    demo.launch()