File size: 6,766 Bytes
b568b83
 
aac8545
 
 
b568b83
3c2f1c5
 
 
 
aac8545
3c2f1c5
 
b568b83
aac8545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
 
 
 
 
 
 
 
 
1c9c65e
 
 
 
 
 
 
 
 
 
 
 
 
aac8545
1c9c65e
 
 
 
 
 
 
 
aac8545
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
aac8545
 
 
1c9c65e
aac8545
b568b83
 
a5a91fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
 
a5a91fa
 
b568b83
 
c8b0bcb
98694e9
 
 
c8b0bcb
 
98694e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import requests
import gradio as gr
import os
import json
from pathlib import Path

# Prefer internal backend when running inside Docker
API_HOST = os.getenv("API_HOST", "localhost")
API_PORT = os.getenv("API_PORT", "8000")

# USE_MOCK=1 bypasses the backend entirely and serves mock responses
# (optionally derived from the latest benchmark summary on disk).
USE_MOCK = os.environ.get("USE_MOCK", "0") == "1"
# Backend endpoints: SQLite upload and the NL->SQL query route.
API_UPLOAD = f"http://{API_HOST}:{API_PORT}/api/v1/nl2sql/upload_db"
API_QUERY = f"http://{API_HOST}:{API_PORT}/api/v1/nl2sql"

# Last-resort mock response, shaped like a real pipeline reply so the UI
# unpacking in query_to_sql works unchanged: sql / rationale / result /
# traces (one entry per pipeline stage) / metrics.
HARDCODED_MOCK = {
    "sql": "SELECT name, country FROM singer WHERE age > 20;",
    "rationale": "Example: select singers older than 20.",
    "result": {
        "rows": 5,
        "columns": ["name", "country"],
        "rows_data": [["Alice", "France"], ["Bob", "USA"]],
    },
    "traces": [
        {"stage": "detector", "summary": "ok", "duration_ms": 5},
        {"stage": "planner", "summary": "intent parsed", "duration_ms": 120},
        {"stage": "generator", "summary": "sql generated", "duration_ms": 420},
        {"stage": "verifier", "summary": "passed", "duration_ms": 10},
    ],
    "metrics": {"EM": 0.15, "SM": 0.70, "ExecAcc": 0.73, "avg_latency_ms": 8113},
}


def load_mock_from_summary():
    """Build a mock pipeline response from the latest benchmark summary.

    Scans ``benchmarks/results_pro/*/summary.json`` for the most recently
    modified summary and maps its fields onto the response shape the UI
    expects (sql / rationale / result / traces / metrics).

    Returns:
        dict: A pipeline-shaped response; HARDCODED_MOCK when no summary
        exists or reading/parsing it fails.
    """
    try:
        # Only the newest file matters — max() avoids sorting the whole list.
        latest = max(
            Path("benchmarks/results_pro").glob("*/summary.json"),
            key=lambda p: p.stat().st_mtime,
            default=None,
        )
        if latest is not None:
            sj = json.loads(latest.read_text(encoding="utf-8"))
            defaults = HARDCODED_MOCK["metrics"]
            return {
                "sql": sj.get("example_sql", HARDCODED_MOCK["sql"]),
                "rationale": sj.get("note", HARDCODED_MOCK["rationale"]),
                "result": {"rows": sj.get("total_samples", 0), "columns": []},
                # No per-stage timings in summary.json; reuse the mock trace.
                "traces": HARDCODED_MOCK["traces"],
                "metrics": {
                    "EM": sj.get("avg_em", defaults["EM"]),
                    "SM": sj.get("avg_sm", defaults["SM"]),
                    "ExecAcc": sj.get("avg_execacc", defaults["ExecAcc"]),
                    "avg_latency_ms": sj.get(
                        "avg_latency_ms", defaults["avg_latency_ms"]
                    ),
                },
            }
    except Exception as e:
        # Best-effort: a broken summary must never crash the demo, but the
        # failure should at least be visible in the logs (was silently passed).
        print(f"[demo] Could not load benchmark summary ({e}); using hardcoded mock.")
    return HARDCODED_MOCK


def call_pipeline_api_or_mock(query: str, db_id: str | None = None, timeout=10):
    """Call backend if available; otherwise return mock."""
    if USE_MOCK:
        return load_mock_from_summary()

    # Building the payload cannot raise, so it sits outside the try block.
    payload = {"query": query}
    if db_id:
        payload["db_id"] = db_id
    try:
        response = requests.post(API_QUERY, json=payload, timeout=timeout)
        response.raise_for_status()
        # .json() stays inside the try: a non-JSON body also falls back to mock.
        return response.json()
    except Exception as e:
        print(f"[demo] API call failed ({e}); using mock instead.")
        return load_mock_from_summary()


def upload_db(file_obj):
    """Upload a user-provided SQLite file to the backend.

    Args:
        file_obj: Gradio file wrapper or None. Relies on ``.name`` (local
            path); ``.size`` is used opportunistically when present.

    Returns:
        tuple: (db_id or None, human-readable status message). db_id is
        None on any failure so the caller falls back to the default DB.
    """
    if file_obj is None:
        return None, "No DB uploaded. Default DB will be used."
    name = getattr(file_obj, "name", "db.sqlite")
    if not name.endswith((".db", ".sqlite")):
        return None, "Only .db or .sqlite files are allowed."
    size = getattr(file_obj, "size", None)
    if size and size > 20 * 1024 * 1024:
        return None, "File too large (>20MB) for this demo."

    try:
        # Context manager guarantees the handle is closed on every path.
        # The previous try/finally closed the file but let a connection
        # error propagate with `r` unbound, crashing the handler.
        with open(file_obj.name, "rb") as fh:
            files = {"file": (name, fh, "application/octet-stream")}
            r = requests.post(API_UPLOAD, files=files, timeout=120)
    except OSError as e:
        return None, f"Could not read uploaded file: {e}"
    except requests.RequestException as e:
        return None, f"Upload failed (backend unreachable): {e}"

    if r.ok:
        data = r.json()
        return data.get("db_id"), f"Uploaded OK. db_id={data.get('db_id')}"
    # Error response: prefer the JSON body, fall back to raw text.
    try:
        body = r.json()
    except ValueError:
        body = r.text
    return None, f"Upload failed ({r.status_code}): {body}"


def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool):
    if not user_query.strip():
        return "❌ Please enter a query.", "", "", {}, [], [], "", []

    data = call_pipeline_api_or_mock(user_query, db_id)
    sql = data.get("sql") or ""
    explanation = data.get("rationale") or ""
    result = data.get("result", {})
    trace_list = data.get("traces", [])

    metrics = data.get("metrics", {})
    badges_text = (
        f"EM={metrics.get('EM', '?')} | SM={metrics.get('SM', '?')} | "
        f"ExecAcc={metrics.get('ExecAcc', '?')} | latency={metrics.get('avg_latency_ms', '?')}ms"
    )

    timings_table = []
    if trace_list and all("duration_ms" in t for t in trace_list):
        timings_table = [[t["stage"], t["duration_ms"]] for t in trace_list]

    return badges_text, sql, explanation, result, trace_list, [], "", timings_table


def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Layout, top to bottom: DB upload row, question/run row, metric badges,
    SQL + explanation outputs, then tabbed panels (Result / Trace / Repair /
    Timings). Statement order defines the on-screen layout, so it matters.

    Returns:
        gr.Blocks: the fully-wired demo, ready for .launch().
    """
    with gr.Blocks(title="NL2SQL Copilot") as demo:
        gr.Markdown("# NL2SQL Copilot\nUpload a SQLite DB (optional) or use default.")
        # Holds the backend-assigned db_id after a successful upload;
        # None means the backend's default database is used.
        db_state = gr.State(value=None)

        with gr.Row():
            db_file = gr.File(
                label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"]
            )
            upload_btn = gr.Button("Upload DB")
        db_msg = gr.Markdown()
        # upload_db returns (db_id, status message) -> (db_state, db_msg).
        upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])

        with gr.Row():
            q = gr.Textbox(label="Question", scale=4)
            # Passed through to query_to_sql as _debug_flag but currently
            # unused there — UI-only toggle.
            debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
            run = gr.Button("Run")

        badges = gr.Markdown()
        sql_out = gr.Code(label="Final SQL", language="sql")
        exp_out = gr.Textbox(label="Explanation", lines=3)

        with gr.Tab("Result"):
            res_out = gr.JSON()

        with gr.Tab("Trace"):
            trace = gr.JSON(label="Stage trace")

        # NOTE(review): query_to_sql currently always returns an empty
        # candidates list and diff string for this tab — placeholder outputs.
        with gr.Tab("Repair"):
            repair_candidates = gr.JSON(label="Candidates")
            repair_diff = gr.Textbox(label="Diff (if any)", lines=10)

        with gr.Tab("Timings"):
            timings = gr.Dataframe(headers=["metric", "ms"], datatype=["str", "number"])

        # Output order must match the 8-tuple returned by query_to_sql.
        run.click(
            query_to_sql,
            inputs=[q, db_state, debug],
            outputs=[
                badges,
                sql_out,
                exp_out,
                res_out,
                trace,
                repair_candidates,
                repair_diff,
                timings,
            ],
        )
    return demo


# expose for SDK mode (no Docker)
demo = build_ui()

if __name__ == "__main__":
    # NOTE(review): the banner hard-codes 0.0.0.0:7860 while the actual
    # bind address/port come from the env vars below — they can diverge.
    print("[demo] Launching Gradio demo on 0.0.0.0:7860 ...", flush=True)
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        share=False,
        debug=True,
    )