Melika Kheirieh
refactor(ui): clean up Repair tab, sync outputs with backend and add examples
c9bbfcd
raw
history blame
7.85 kB
import requests
import gradio as gr
import os
import json
from pathlib import Path
# Backend location. The defaults target a backend on the local machine; the
# environment variables let a Docker deployment point at the internal service.
API_HOST = os.getenv("API_HOST", "localhost")
API_PORT = os.getenv("API_PORT", "8000")

# Only the exact value "1" switches the demo to mock-only mode.
USE_MOCK = os.getenv("USE_MOCK", "0") == "1"

# Both endpoints hang off the same versioned NL2SQL route.
_NL2SQL_ROOT = "http://" + API_HOST + ":" + API_PORT + "/api/v1/nl2sql"
API_QUERY = _NL2SQL_ROOT
API_UPLOAD = _NL2SQL_ROOT + "/upload_db"
# (stage, summary, duration in ms) for each entry of the mock trace, in order.
_MOCK_STAGE_TIMINGS = (
    ("detector", "ok", 5),
    ("planner", "intent parsed", 120),
    ("generator", "sql generated", 420),
    ("verifier", "passed", 10),
)

# Static stand-in for a backend response, used when no benchmark summary can
# be loaded. It carries the same top-level keys the UI reads:
# sql, rationale, result, traces and metrics.
HARDCODED_MOCK = {
    "sql": "SELECT name, country FROM singer WHERE age > 20;",
    "rationale": "Example: select singers older than 20.",
    "result": {
        "rows": 5,
        "columns": ["name", "country"],
        "rows_data": [["Alice", "France"], ["Bob", "USA"]],
    },
    "traces": [
        {"stage": stage, "summary": summary, "duration_ms": duration_ms}
        for stage, summary, duration_ms in _MOCK_STAGE_TIMINGS
    ],
    "metrics": {"EM": 0.15, "SM": 0.70, "ExecAcc": 0.73, "avg_latency_ms": 8113},
}
def load_mock_from_summary(results_dir: str | Path = "benchmarks/results_pro"):
    """Build a mock pipeline response from the newest benchmark summary.

    Looks for ``<results_dir>/*/summary.json``, picks the most recently
    modified file and maps its fields onto the response shape the UI reads.
    Any field missing from the summary falls back to ``HARDCODED_MOCK``.

    Args:
        results_dir: Directory whose immediate sub-directories may each
            contain a ``summary.json``.

    Returns:
        A dict with ``sql``, ``rationale``, ``result``, ``traces`` and
        ``metrics`` keys. ``HARDCODED_MOCK`` itself is returned when no
        summary exists or the newest one cannot be read as a JSON object.
    """
    try:
        candidates = sorted(
            Path(results_dir).glob("*/summary.json"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        if not candidates:
            return HARDCODED_MOCK
        with open(candidates[0], "r", encoding="utf-8") as f:
            sj = json.load(f)
    except (OSError, ValueError) as exc:
        # Still best-effort, but say why the summary was ignored instead of
        # swallowing the error. ValueError covers malformed JSON and bad UTF-8.
        print(f"[demo] Could not load benchmark summary ({exc}); using hardcoded mock.")
        return HARDCODED_MOCK

    if not isinstance(sj, dict):
        # A summary that is valid JSON but not an object has no fields to map.
        print("[demo] Benchmark summary is not a JSON object; using hardcoded mock.")
        return HARDCODED_MOCK

    default_metrics = HARDCODED_MOCK["metrics"]
    return {
        "sql": sj.get("example_sql", HARDCODED_MOCK["sql"]),
        "rationale": sj.get("note", HARDCODED_MOCK["rationale"]),
        "result": {"rows": sj.get("total_samples", 0), "columns": []},
        "traces": HARDCODED_MOCK["traces"],
        "metrics": {
            "EM": sj.get("avg_em", default_metrics["EM"]),
            "SM": sj.get("avg_sm", default_metrics["SM"]),
            "ExecAcc": sj.get("avg_execacc", default_metrics["ExecAcc"]),
            "avg_latency_ms": sj.get(
                "avg_latency_ms", default_metrics["avg_latency_ms"]
            ),
        },
    }
def call_pipeline_api_or_mock(query: str, db_id: str | None = None, timeout=10):
    """Run *query* through the backend, falling back to mock data on failure.

    Args:
        query: Natural-language question to send to the NL2SQL endpoint.
        db_id: Optional database identifier; only included in the request
            payload when truthy.
        timeout: Timeout passed to ``requests.post``.

    Returns:
        The backend's JSON object, or the result of
        ``load_mock_from_summary()`` when ``USE_MOCK`` is set, the request
        fails, or the response is not a JSON object.
    """
    if USE_MOCK:
        return load_mock_from_summary()

    payload = {"query": query}
    if db_id:
        payload["db_id"] = db_id

    try:
        r = requests.post(API_QUERY, json=payload, timeout=timeout)
        r.raise_for_status()
        data = r.json()
        # Callers read the response with .get(); anything other than a JSON
        # object is treated as a failed call so they never see a list/str/None.
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return data
    except Exception as e:
        # Deliberately broad: the demo must keep working without a backend.
        print(f"[demo] API call failed ({e}); using mock instead.")
        return load_mock_from_summary()
def upload_db(file_obj):
    """Send an uploaded SQLite file to the backend and report the outcome.

    Args:
        file_obj: Value produced by the file widget. May be ``None`` (nothing
            uploaded), a filesystem path (``str`` / ``os.PathLike``), or an
            object exposing the path as ``.name`` and optionally a ``.size``.

    Returns:
        A ``(db_id, message)`` tuple. ``db_id`` is ``None`` whenever the file
        was rejected locally or the upload did not succeed; ``message`` is a
        human-readable status line.
    """
    if file_obj is None:
        return None, "No DB uploaded. Default DB will be used."

    # Accept both a bare path and a file-like wrapper that carries the path
    # in `.name`, so the handler works with either form of widget value.
    if isinstance(file_obj, (str, os.PathLike)):
        path = os.fspath(file_obj)
    else:
        path = getattr(file_obj, "name", None)
    if not path:
        return None, "Could not determine the uploaded file's path."

    # Send only the base name upstream, not the local temp-directory path.
    filename = os.path.basename(path)
    if not filename.lower().endswith((".db", ".sqlite")):
        return None, "Only .db or .sqlite files are allowed."

    max_bytes = 20 * 1024 * 1024
    size = getattr(file_obj, "size", None)
    if size is None:
        # No size attribute: ask the filesystem so the limit is still enforced.
        try:
            size = os.path.getsize(path)
        except OSError:
            size = None
    if size and size > max_bytes:
        return None, "File too large (>20MB) for this demo."

    try:
        fh = open(path, "rb")
    except OSError as exc:
        return None, f"Could not read uploaded file: {exc}"

    with fh:
        try:
            r = requests.post(
                API_UPLOAD,
                files={"file": (filename, fh, "application/octet-stream")},
                timeout=120,
            )
        except requests.RequestException as exc:
            # Connection errors and timeouts become a status message instead
            # of an unhandled exception in the UI callback.
            return None, f"Upload failed: {exc}"

    try:
        body = r.json()
    except ValueError:
        body = r.text

    if r.ok:
        if isinstance(body, dict):
            db_id = body.get("db_id")
            return db_id, f"Uploaded OK. db_id={db_id}"
        return None, f"Upload returned an unexpected response: {body}"
    return None, f"Upload failed ({r.status_code}): {body}"
def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool):
if not user_query.strip():
return "❌ Please enter a query.", "", "", {}, [], []
data = call_pipeline_api_or_mock(user_query, db_id)
sql = data.get("sql") or ""
explanation = data.get("rationale") or ""
result = data.get("result", {})
trace_list = data.get("traces", [])
metrics = data.get("metrics", {})
badges_text = (
f"EM={metrics.get('EM', '?')} | SM={metrics.get('SM', '?')} | "
f"ExecAcc={metrics.get('ExecAcc', '?')} | latency={metrics.get('avg_latency_ms', '?')}ms"
)
timings_table = []
if trace_list and all("duration_ms" in t for t in trace_list):
timings_table = [[t["stage"], t["duration_ms"]] for t in trace_list]
# Note: repair candidates / diff are not exposed in the UI yet.
return badges_text, sql, explanation, result, trace_list, timings_table
# Build the Gradio Blocks UI for the demo and return it; nothing is launched
# here.
# NOTE(review): leading indentation is missing in this copy of the file, so
# which widgets belong to each gr.Row()/gr.Tab() context cannot be read off
# the lines below - confirm the nesting against the original before editing.
def build_ui() -> gr.Blocks:
with gr.Blocks(title="NL2SQL Copilot") as demo:
gr.Markdown("# NL2SQL Copilot\nUpload a SQLite DB (optional) or use default.")
# Holds the first value returned by upload_db (a db_id, or None); starts as None.
db_state = gr.State(value=None)
with gr.Row():
db_file = gr.File(
label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"]
)
upload_btn = gr.Button("Upload DB")
db_msg = gr.Markdown()
# upload_db returns (db_id, status message): the id goes into db_state and
# the message is rendered in db_msg.
upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])
with gr.Row():
q = gr.Textbox(label="Question", scale=4)
# The checkbox value is passed to query_to_sql as `_debug_flag`, which that
# function does not read.
debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
run = gr.Button("Run")
# Example queries to make the demo easier to explore
gr.Examples(
examples=[
["List all artists"],
["Top 5 customers by total invoice amount"],
["Total number of tracks per genre"],
["Top 3 albums by total sales"],
],
inputs=[q],
label="Try these example queries",
)
# Output widgets; they are filled by the run.click() wiring further down.
badges = gr.Markdown()
sql_out = gr.Code(label="Final SQL", language="sql")
exp_out = gr.Textbox(label="Explanation", lines=3)
with gr.Tab("Result"):
res_out = gr.JSON()
with gr.Tab("Trace"):
trace = gr.JSON(label="Stage trace")
# Static text only: this tab has no output component wired to query_to_sql.
with gr.Tab("Repair"):
gr.Markdown(
"""
### Repair & self-healing (pipeline-level)
The repair loop is fully implemented in the backend:
* If a candidate SQL fails safety or execution checks,
the pipeline attempts to **repair** it.
* All repair attempts and outcomes are tracked in Prometheus
(for example, `nl2sql_repair_attempts_total` and related rates).
For now, detailed before/after SQL diff and repair candidates
are exposed via trace logs and metrics dashboards.
This tab is reserved for a future, richer UI:
side-by-side SQL diff, repair candidates, and explanations.
"""
)
with gr.Tab("Timings"):
# Receives the [stage, duration_ms] rows that query_to_sql builds from the traces.
timings = gr.Dataframe(
headers=["stage", "duration_ms"], datatype=["str", "number"]
)
# The outputs list must stay in the order of the tuple query_to_sql returns:
# (badges text, sql, explanation, result, trace list, timings rows).
run.click(
query_to_sql,
inputs=[q, db_state, debug],
outputs=[
badges,
sql_out,
exp_out,
res_out,
trace,
timings,
],
)
return demo
# expose for SDK mode (no Docker): importing this module builds the UI so a
# host can pick up `demo` without running the __main__ branch below.
demo = build_ui()

if __name__ == "__main__":
    # Resolve the bind address once so the startup message reports the values
    # actually passed to launch(), including any environment overrides.
    server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    print(
        f"[demo] Launching Gradio demo on {server_name}:{server_port} ...",
        flush=True,
    )
    demo.launch(
        server_name=server_name,
        server_port=server_port,
        share=False,
        debug=True,
    )