Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
1c9c65e
1
Parent(s):
5cbfffe
fix(ui): align payload with API schema and handle 422 errors
Browse files- demo/app.py +86 -35
demo/app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import io
|
| 2 |
import requests
|
| 3 |
import gradio as gr
|
| 4 |
|
|
@@ -14,54 +13,105 @@ def upload_db(file_obj):
|
|
| 14 |
return None, "Only .db or .sqlite files are allowed."
|
| 15 |
size = getattr(file_obj, "size", None)
|
| 16 |
if size and size > 20 * 1024 * 1024:
|
| 17 |
-
return None, "File too large (>20MB)
|
| 18 |
-
|
| 19 |
-
#
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
r = requests.post(API_QUERY, json=payload, timeout=120)
|
| 38 |
-
r.
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
explanation = d.get("explanation", "")
|
| 43 |
-
result = d.get("result", [])
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
safety = (
|
| 48 |
"Allowed"
|
| 49 |
if d.get("safety", {}).get("allowed", True)
|
| 50 |
else f"Blocked: {d.get('safety', {}).get('blocked_reason')}"
|
| 51 |
)
|
| 52 |
verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
|
| 53 |
-
repair = d.get("repair", {})
|
| 54 |
repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"
|
| 55 |
|
| 56 |
-
timings = d.get("timings_ms", {})
|
| 57 |
timings_table = [[k, timings[k]] for k in sorted(timings.keys())]
|
| 58 |
|
|
|
|
|
|
|
| 59 |
return (
|
| 60 |
-
|
| 61 |
sql,
|
| 62 |
explanation,
|
| 63 |
result,
|
| 64 |
-
|
| 65 |
repair.get("candidates", []),
|
| 66 |
repair.get("diff", ""),
|
| 67 |
timings_table,
|
|
@@ -83,7 +133,8 @@ with gr.Blocks(title="NL2SQL Copilot") as demo:
|
|
| 83 |
|
| 84 |
with gr.Row():
|
| 85 |
q = gr.Textbox(label="Question", scale=4)
|
| 86 |
-
|
|
|
|
| 87 |
run = gr.Button("Run")
|
| 88 |
|
| 89 |
badges = gr.Markdown()
|
|
@@ -98,10 +149,10 @@ with gr.Blocks(title="NL2SQL Copilot") as demo:
|
|
| 98 |
|
| 99 |
with gr.Tab("Repair"):
|
| 100 |
repair_candidates = gr.JSON(label="Candidates")
|
| 101 |
-
repair_diff = gr.Code(label="
|
| 102 |
|
| 103 |
with gr.Tab("Timings"):
|
| 104 |
-
timings = gr.Dataframe(headers=["
|
| 105 |
|
| 106 |
run.click(
|
| 107 |
query_to_sql,
|
|
|
|
|
|
|
| 1 |
import requests
|
| 2 |
import gradio as gr
|
| 3 |
|
|
|
|
| 13 |
return None, "Only .db or .sqlite files are allowed."
|
| 14 |
size = getattr(file_obj, "size", None)
|
| 15 |
if size and size > 20 * 1024 * 1024:
|
| 16 |
+
return None, "File too large (>20MB) for this demo."
|
| 17 |
+
|
| 18 |
+
# Gradio gives a temp file path as file_obj.name
|
| 19 |
+
files = {"file": (name, open(file_obj.name, "rb"), "application/octet-stream")}
|
| 20 |
+
try:
|
| 21 |
+
r = requests.post(API_UPLOAD, files=files, timeout=120)
|
| 22 |
+
finally:
|
| 23 |
+
# best-effort close
|
| 24 |
+
try:
|
| 25 |
+
files["file"][1].close()
|
| 26 |
+
except Exception:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
if r.ok:
|
| 30 |
+
data = r.json()
|
| 31 |
+
db_id = data.get("db_id")
|
| 32 |
+
if not db_id:
|
| 33 |
+
return None, f"Upload returned no db_id: {data}"
|
| 34 |
+
return db_id, f"Uploaded OK. db_id={db_id}"
|
| 35 |
+
# Show backend error body
|
| 36 |
+
try:
|
| 37 |
+
body = r.json()
|
| 38 |
+
except ValueError:
|
| 39 |
+
body = r.text
|
| 40 |
+
return None, f"Upload failed ({r.status_code}): {body}"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _post_query(payload: dict):
|
| 44 |
+
"""Helper: POST and return (ok, data_or_error_string)."""
|
| 45 |
r = requests.post(API_QUERY, json=payload, timeout=120)
|
| 46 |
+
if r.ok:
|
| 47 |
+
try:
|
| 48 |
+
return True, r.json()
|
| 49 |
+
except ValueError:
|
| 50 |
+
return False, "Backend returned non-JSON body."
|
| 51 |
+
try:
|
| 52 |
+
body = r.json()
|
| 53 |
+
except ValueError:
|
| 54 |
+
body = r.text
|
| 55 |
+
return False, f"{r.status_code} {body}"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool):
|
| 59 |
+
# Build minimal schema-compliant payload.
|
| 60 |
+
# Server expects request.query (name is 'query' per router code).
|
| 61 |
+
base_payload = {"query": user_query.strip() if user_query else ""}
|
| 62 |
+
|
| 63 |
+
# First try WITH db_id (if present). If backend rejects (422), retry WITHOUT.
|
| 64 |
+
if db_id:
|
| 65 |
+
ok, data = _post_query({**base_payload, "db_id": db_id})
|
| 66 |
+
if not ok and isinstance(data, str) and data.startswith("422"):
|
| 67 |
+
# Retry without db_id in case request model forbids extra fields.
|
| 68 |
+
ok, data = _post_query(base_payload)
|
| 69 |
+
else:
|
| 70 |
+
ok, data = _post_query(base_payload)
|
| 71 |
+
|
| 72 |
+
if not ok:
|
| 73 |
+
# Surface backend error text to the UI
|
| 74 |
+
err_badges = f"Error: {data}"
|
| 75 |
+
return (
|
| 76 |
+
err_badges, # badges
|
| 77 |
+
"", # sql_out
|
| 78 |
+
"", # exp_out
|
| 79 |
+
{}, # result (tab)
|
| 80 |
+
[], # trace (tab)
|
| 81 |
+
[], # repair_candidates (tab)
|
| 82 |
+
"", # repair_diff (tab)
|
| 83 |
+
[], # timings (tab)
|
| 84 |
+
)
|
| 85 |
|
| 86 |
+
d = data
|
|
|
|
|
|
|
| 87 |
|
| 88 |
+
# Map fields to UI (server returns: ambiguous, sql, rationale, traces)
|
| 89 |
+
sql = d.get("sql") or d.get("sql_final") or ""
|
| 90 |
+
explanation = d.get("rationale") or d.get("explanation") or ""
|
| 91 |
+
result = d.get("result", {}) # optional/maybe absent
|
| 92 |
+
trace_list = d.get("traces") or d.get("trace") or []
|
| 93 |
+
|
| 94 |
+
ambiguous_flag = "Yes" if d.get("ambiguous") else "No"
|
| 95 |
safety = (
|
| 96 |
"Allowed"
|
| 97 |
if d.get("safety", {}).get("allowed", True)
|
| 98 |
else f"Blocked: {d.get('safety', {}).get('blocked_reason')}"
|
| 99 |
)
|
| 100 |
verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
|
| 101 |
+
repair = d.get("repair", {}) or {}
|
| 102 |
repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"
|
| 103 |
|
| 104 |
+
timings = d.get("timings_ms", {}) or {}
|
| 105 |
timings_table = [[k, timings[k]] for k in sorted(timings.keys())]
|
| 106 |
|
| 107 |
+
badges_text = f"Ambiguous: {ambiguous_flag} | Safety: {safety} | Verification: {verification} | Repair: {repair_text}"
|
| 108 |
+
|
| 109 |
return (
|
| 110 |
+
badges_text,
|
| 111 |
sql,
|
| 112 |
explanation,
|
| 113 |
result,
|
| 114 |
+
trace_list,
|
| 115 |
repair.get("candidates", []),
|
| 116 |
repair.get("diff", ""),
|
| 117 |
timings_table,
|
|
|
|
| 133 |
|
| 134 |
with gr.Row():
|
| 135 |
q = gr.Textbox(label="Question", scale=4)
|
| 136 |
+
# keep the checkbox in UI if you like, but we don't send it to backend
|
| 137 |
+
debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
|
| 138 |
run = gr.Button("Run")
|
| 139 |
|
| 140 |
badges = gr.Markdown()
|
|
|
|
| 149 |
|
| 150 |
with gr.Tab("Repair"):
|
| 151 |
repair_candidates = gr.JSON(label="Candidates")
|
| 152 |
+
repair_diff = gr.Code(label="Diff (if any)", language="diff")
|
| 153 |
|
| 154 |
with gr.Tab("Timings"):
|
| 155 |
+
timings = gr.Dataframe(headers=["metric", "ms"], datatype=["str", "number"])
|
| 156 |
|
| 157 |
run.click(
|
| 158 |
query_to_sql,
|