Spaces:
Sleeping
Sleeping
| import requests | |
| import gradio as gr | |
| import os | |
| import json | |
| from pathlib import Path | |
# Prefer internal backend when running inside Docker.
# Host/port of the FastAPI backend; overridable via environment variables.
API_HOST = os.getenv("API_HOST", "localhost")
API_PORT = os.getenv("API_PORT", "8000")
# When USE_MOCK=1 the UI never contacts the backend and serves mock data only.
USE_MOCK = os.environ.get("USE_MOCK", "0") == "1"
# Backend endpoints: database upload and NL->SQL query.
API_UPLOAD = f"http://{API_HOST}:{API_PORT}/api/v1/nl2sql/upload_db"
API_QUERY = f"http://{API_HOST}:{API_PORT}/api/v1/nl2sql"

# Static fallback payload mirroring the backend response schema
# (sql / rationale / result / traces / metrics). Used when neither the API
# nor a benchmark summary.json is available.
HARDCODED_MOCK = {
    "sql": "SELECT name, country FROM singer WHERE age > 20;",
    "rationale": "Example: select singers older than 20.",
    "result": {
        "rows": 5,
        "columns": ["name", "country"],
        "rows_data": [["Alice", "France"], ["Bob", "USA"]],
    },
    "traces": [
        {"stage": "detector", "summary": "ok", "duration_ms": 5},
        {"stage": "planner", "summary": "intent parsed", "duration_ms": 120},
        {"stage": "generator", "summary": "sql generated", "duration_ms": 420},
        {"stage": "verifier", "summary": "passed", "duration_ms": 10},
    ],
    "metrics": {"EM": 0.15, "SM": 0.70, "ExecAcc": 0.73, "avg_latency_ms": 8113},
}
def load_mock_from_summary():
    """Try to read latest benchmark summary.json; fallback to hardcoded mock."""
    try:
        candidates = list(Path("benchmarks/results_pro").glob("*/summary.json"))
        if not candidates:
            return HARDCODED_MOCK
        # Most recently modified summary wins.
        latest = max(candidates, key=lambda p: p.stat().st_mtime)
        sj = json.loads(latest.read_text(encoding="utf-8"))
        defaults = HARDCODED_MOCK["metrics"]
        return {
            "sql": sj.get("example_sql", HARDCODED_MOCK["sql"]),
            "rationale": sj.get("note", HARDCODED_MOCK["rationale"]),
            "result": {"rows": sj.get("total_samples", 0), "columns": []},
            # Benchmark summaries carry no stage traces; reuse the mock ones.
            "traces": HARDCODED_MOCK["traces"],
            "metrics": {
                "EM": sj.get("avg_em", defaults["EM"]),
                "SM": sj.get("avg_sm", defaults["SM"]),
                "ExecAcc": sj.get("avg_execacc", defaults["ExecAcc"]),
                "avg_latency_ms": sj.get(
                    "avg_latency_ms", defaults["avg_latency_ms"]
                ),
            },
        }
    except Exception:
        # Deliberate best-effort: any filesystem/JSON problem falls through
        # to the static mock below.
        pass
    return HARDCODED_MOCK
def call_pipeline_api_or_mock(query: str, db_id: str | None = None, timeout=10):
    """Call backend if available; otherwise return mock."""
    if USE_MOCK:
        return load_mock_from_summary()

    # Building the payload cannot raise, so it lives outside the try block.
    payload = {"query": query}
    if db_id:
        payload["db_id"] = db_id

    try:
        resp = requests.post(API_QUERY, json=payload, timeout=timeout)
        resp.raise_for_status()
        return resp.json()
    except Exception as exc:
        # Any failure (connection, HTTP status, JSON decode) degrades to mock.
        print(f"[demo] API call failed ({exc}); using mock instead.")
        return load_mock_from_summary()
def upload_db(file_obj):
    """Upload a user-provided SQLite file to the backend.

    Args:
        file_obj: Gradio File component value (exposes ``.name`` as the temp
            file path, and possibly ``.size``) or None when nothing was chosen.

    Returns:
        ``(db_id, message)`` tuple. ``db_id`` is None on any failure so the
        UI keeps using the default database.
    """
    if file_obj is None:
        return None, "No DB uploaded. Default DB will be used."
    name = getattr(file_obj, "name", "db.sqlite")
    # str.endswith accepts a tuple: one check covers both allowed extensions.
    if not name.endswith((".db", ".sqlite")):
        return None, "Only .db or .sqlite files are allowed."
    size = getattr(file_obj, "size", None)
    if size and size > 20 * 1024 * 1024:
        return None, "File too large (>20MB) for this demo."
    try:
        # Gradio gives a temp file path as file_obj.name; the context manager
        # guarantees the handle is closed even if the POST raises.
        with open(file_obj.name, "rb") as fh:
            files = {"file": (name, fh, "application/octet-stream")}
            r = requests.post(API_UPLOAD, files=files, timeout=120)
    except OSError as e:
        return None, f"Upload failed: could not read file ({e})"
    except Exception as e:
        # Network errors previously propagated and crashed the Gradio handler;
        # surface them as a user-visible message instead.
        return None, f"Upload failed: {e}"
    if r.ok:
        data = r.json()
        return data.get("db_id"), f"Uploaded OK. db_id={data.get('db_id')}"
    # Non-2xx: prefer the JSON error body, fall back to raw text.
    try:
        body = r.json()
    except ValueError:
        body = r.text
    return None, f"Upload failed ({r.status_code}): {body}"
| def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool): | |
| """Unified query handler: tries backend or mock fallback.""" | |
| if not user_query.strip(): | |
| return "❌ Please enter a query.", "", "", {}, [], [], "", [] | |
| data = call_pipeline_api_or_mock(user_query, db_id) | |
| sql = data.get("sql") or "" | |
| explanation = data.get("rationale") or "" | |
| result = data.get("result", {}) | |
| trace_list = data.get("traces", []) | |
| metrics = data.get("metrics", {}) | |
| badges_text = ( | |
| f"EM={metrics.get('EM', '?')} | SM={metrics.get('SM', '?')} | " | |
| f"ExecAcc={metrics.get('ExecAcc', '?')} | latency={metrics.get('avg_latency_ms', '?')}ms" | |
| ) | |
| timings_table = [] | |
| if trace_list and all("duration_ms" in t for t in trace_list): | |
| timings_table = [[t["stage"], t["duration_ms"]] for t in trace_list] | |
| return badges_text, sql, explanation, result, trace_list, [], "", timings_table | |
# ---- UI definition (unchanged) ----
with gr.Blocks(title="NL2SQL Copilot") as demo:
    gr.Markdown("# NL2SQL Copilot\nUpload a SQLite DB (optional) or use default.")
    # Holds the db_id returned by the upload endpoint; None -> default DB.
    db_state = gr.State(value=None)
    with gr.Row():
        db_file = gr.File(
            label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"]
        )
        upload_btn = gr.Button("Upload DB")
    db_msg = gr.Markdown()
    upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])
    with gr.Row():
        q = gr.Textbox(label="Question", scale=4)
        # NOTE(review): the debug flag is passed to query_to_sql but unused there
        # (parameter name `_debug_flag`) — "UI only" as the label says.
        debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
    run = gr.Button("Run")
    badges = gr.Markdown()
    sql_out = gr.Code(label="Final SQL", language="sql")
    exp_out = gr.Textbox(label="Explanation", lines=3)
    with gr.Tab("Result"):
        res_out = gr.JSON()
    with gr.Tab("Trace"):
        trace = gr.JSON(label="Stage trace")
    with gr.Tab("Repair"):
        # query_to_sql currently always returns empty candidates and an empty diff.
        repair_candidates = gr.JSON(label="Candidates")
        repair_diff = gr.Textbox(label="Diff (if any)", lines=10)
    with gr.Tab("Timings"):
        timings = gr.Dataframe(headers=["metric", "ms"], datatype=["str", "number"])
    # Output order must match the 8-tuple returned by query_to_sql.
    run.click(
        query_to_sql,
        inputs=[q, db_state, debug],
        outputs=[
            badges,
            sql_out,
            exp_out,
            res_out,
            trace,
            repair_candidates,
            repair_diff,
            timings,
        ],
    )
if __name__ == "__main__":
    # `os` is already imported at module top; the redundant local import
    # that shadowed it here has been removed.
    print("[demo] Launching Gradio demo on 0.0.0.0:7860 ...", flush=True)
    # Bind address/port are overridable via environment for container deployments.
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        share=False,
        debug=True,
    )