Melika Kheirieh commited on
Commit
1c9c65e
·
1 Parent(s): 5cbfffe

fix(ui): align payload with API schema and handle 422 errors

Browse files
Files changed (1) hide show
  1. demo/app.py +86 -35
demo/app.py CHANGED
@@ -1,4 +1,3 @@
1
- import io
2
  import requests
3
  import gradio as gr
4
 
@@ -14,54 +13,105 @@ def upload_db(file_obj):
14
  return None, "Only .db or .sqlite files are allowed."
15
  size = getattr(file_obj, "size", None)
16
  if size and size > 20 * 1024 * 1024:
17
- return None, "File too large (>20MB). Use a smaller demo DB."
18
-
19
- # Read bytes
20
- with open(file_obj.name, "rb") as f:
21
- data = f.read()
22
-
23
- r = requests.post(
24
- API_UPLOAD,
25
- files={"file": (name, io.BytesIO(data), "application/octet-stream")},
26
- timeout=60,
27
- )
28
- r.raise_for_status()
29
- db_id = r.json().get("db_id")
30
- return db_id, f"Uploaded OK. db_id={db_id}"
31
-
32
-
33
- def query_to_sql(user_query, db_id, debug):
34
- payload = {"query": user_query, "debug": bool(debug)}
35
- if db_id:
36
- payload["db_id"] = db_id
 
 
 
 
 
 
 
 
 
37
  r = requests.post(API_QUERY, json=payload, timeout=120)
38
- r.raise_for_status()
39
- d = r.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- sql = d.get("sql_final") or d.get("sql") or ""
42
- explanation = d.get("explanation", "")
43
- result = d.get("result", [])
44
 
45
- # Flags summary
46
- ambiguous = "Yes" if d.get("ambiguous") else "No"
 
 
 
 
 
47
  safety = (
48
  "Allowed"
49
  if d.get("safety", {}).get("allowed", True)
50
  else f"Blocked: {d.get('safety', {}).get('blocked_reason')}"
51
  )
52
  verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
53
- repair = d.get("repair", {})
54
  repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"
55
 
56
- timings = d.get("timings_ms", {})
57
  timings_table = [[k, timings[k]] for k in sorted(timings.keys())]
58
 
 
 
59
  return (
60
- f"Ambiguous: {ambiguous} | Safety: {safety} | Verification: {verification} | Repair: {repair_text}",
61
  sql,
62
  explanation,
63
  result,
64
- d.get("trace", []),
65
  repair.get("candidates", []),
66
  repair.get("diff", ""),
67
  timings_table,
@@ -83,7 +133,8 @@ with gr.Blocks(title="NL2SQL Copilot") as demo:
83
 
84
  with gr.Row():
85
  q = gr.Textbox(label="Question", scale=4)
86
- debug = gr.Checkbox(label="Debug", value=True, scale=1)
 
87
  run = gr.Button("Run")
88
 
89
  badges = gr.Markdown()
@@ -98,10 +149,10 @@ with gr.Blocks(title="NL2SQL Copilot") as demo:
98
 
99
  with gr.Tab("Repair"):
100
  repair_candidates = gr.JSON(label="Candidates")
101
- repair_diff = gr.Code(label="SQL Diff", language="sql")
102
 
103
  with gr.Tab("Timings"):
104
- timings = gr.Dataframe(headers=["stage", "ms"], datatype=["str", "number"])
105
 
106
  run.click(
107
  query_to_sql,
 
 
1
  import requests
2
  import gradio as gr
3
 
 
13
  return None, "Only .db or .sqlite files are allowed."
14
  size = getattr(file_obj, "size", None)
15
  if size and size > 20 * 1024 * 1024:
16
+ return None, "File too large (>20MB) for this demo."
17
+
18
+ # Gradio gives a temp file path as file_obj.name
19
+ files = {"file": (name, open(file_obj.name, "rb"), "application/octet-stream")}
20
+ try:
21
+ r = requests.post(API_UPLOAD, files=files, timeout=120)
22
+ finally:
23
+ # best-effort close
24
+ try:
25
+ files["file"][1].close()
26
+ except Exception:
27
+ pass
28
+
29
+ if r.ok:
30
+ data = r.json()
31
+ db_id = data.get("db_id")
32
+ if not db_id:
33
+ return None, f"Upload returned no db_id: {data}"
34
+ return db_id, f"Uploaded OK. db_id={db_id}"
35
+ # Show backend error body
36
+ try:
37
+ body = r.json()
38
+ except ValueError:
39
+ body = r.text
40
+ return None, f"Upload failed ({r.status_code}): {body}"
41
+
42
+
43
+ def _post_query(payload: dict):
44
+ """Helper: POST and return (ok, data_or_error_string)."""
45
  r = requests.post(API_QUERY, json=payload, timeout=120)
46
+ if r.ok:
47
+ try:
48
+ return True, r.json()
49
+ except ValueError:
50
+ return False, "Backend returned non-JSON body."
51
+ try:
52
+ body = r.json()
53
+ except ValueError:
54
+ body = r.text
55
+ return False, f"{r.status_code} {body}"
56
+
57
+
58
+ def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool):
59
+ # Build minimal schema-compliant payload.
60
+ # Server expects request.query (name is 'query' per router code).
61
+ base_payload = {"query": user_query.strip() if user_query else ""}
62
+
63
+ # First try WITH db_id (if present). If backend rejects (422), retry WITHOUT.
64
+ if db_id:
65
+ ok, data = _post_query({**base_payload, "db_id": db_id})
66
+ if not ok and isinstance(data, str) and data.startswith("422"):
67
+ # Retry without db_id in case request model forbids extra fields.
68
+ ok, data = _post_query(base_payload)
69
+ else:
70
+ ok, data = _post_query(base_payload)
71
+
72
+ if not ok:
73
+ # Surface backend error text to the UI
74
+ err_badges = f"Error: {data}"
75
+ return (
76
+ err_badges, # badges
77
+ "", # sql_out
78
+ "", # exp_out
79
+ {}, # result (tab)
80
+ [], # trace (tab)
81
+ [], # repair_candidates (tab)
82
+ "", # repair_diff (tab)
83
+ [], # timings (tab)
84
+ )
85
 
86
+ d = data
 
 
87
 
88
+ # Map fields to UI (server returns: ambiguous, sql, rationale, traces)
89
+ sql = d.get("sql") or d.get("sql_final") or ""
90
+ explanation = d.get("rationale") or d.get("explanation") or ""
91
+ result = d.get("result", {}) # optional/maybe absent
92
+ trace_list = d.get("traces") or d.get("trace") or []
93
+
94
+ ambiguous_flag = "Yes" if d.get("ambiguous") else "No"
95
  safety = (
96
  "Allowed"
97
  if d.get("safety", {}).get("allowed", True)
98
  else f"Blocked: {d.get('safety', {}).get('blocked_reason')}"
99
  )
100
  verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
101
+ repair = d.get("repair", {}) or {}
102
  repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"
103
 
104
+ timings = d.get("timings_ms", {}) or {}
105
  timings_table = [[k, timings[k]] for k in sorted(timings.keys())]
106
 
107
+ badges_text = f"Ambiguous: {ambiguous_flag} | Safety: {safety} | Verification: {verification} | Repair: {repair_text}"
108
+
109
  return (
110
+ badges_text,
111
  sql,
112
  explanation,
113
  result,
114
+ trace_list,
115
  repair.get("candidates", []),
116
  repair.get("diff", ""),
117
  timings_table,
 
133
 
134
  with gr.Row():
135
  q = gr.Textbox(label="Question", scale=4)
136
+ # keep the checkbox in UI if you like, but we don't send it to backend
137
+ debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
138
  run = gr.Button("Run")
139
 
140
  badges = gr.Markdown()
 
149
 
150
  with gr.Tab("Repair"):
151
  repair_candidates = gr.JSON(label="Candidates")
152
+ repair_diff = gr.Code(label="Diff (if any)", language="diff")
153
 
154
  with gr.Tab("Timings"):
155
+ timings = gr.Dataframe(headers=["metric", "ms"], datatype=["str", "number"])
156
 
157
  run.click(
158
  query_to_sql,