File size: 6,707 Bytes
b568b83
 
aac8545
 
 
b568b83
3c2f1c5
 
 
 
aac8545
3c2f1c5
 
b568b83
aac8545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
 
 
 
 
 
 
 
 
1c9c65e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aac8545
1c9c65e
 
 
 
 
 
 
 
aac8545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
aac8545
 
 
1c9c65e
aac8545
b568b83
 
aac8545
b568b83
 
 
 
 
 
5cbfffe
 
 
b568b83
 
 
 
 
 
1c9c65e
b568b83
 
 
 
 
 
 
 
 
 
 
 
 
 
c4c85f7
aac8545
b568b83
1c9c65e
b568b83
 
 
 
5cbfffe
 
 
 
 
 
 
 
 
 
b568b83
 
 
98694e9
 
c8b0bcb
98694e9
 
 
c8b0bcb
 
98694e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import requests
import gradio as gr
import os
import json
from pathlib import Path

# Prefer internal backend when running inside Docker
# (docker-compose can set API_HOST to the backend service name).
API_HOST = os.getenv("API_HOST", "localhost")
API_PORT = os.getenv("API_PORT", "8000")

# USE_MOCK=1 skips the backend entirely and serves canned/benchmark data.
USE_MOCK = os.environ.get("USE_MOCK", "0") == "1"
# Backend endpoints: SQLite upload and the NL->SQL pipeline.
API_UPLOAD = f"http://{API_HOST}:{API_PORT}/api/v1/nl2sql/upload_db"
API_QUERY = f"http://{API_HOST}:{API_PORT}/api/v1/nl2sql"

# Canned response used when neither the backend nor a benchmark summary is
# available. Mirrors the shape of a real /api/v1/nl2sql response.
HARDCODED_MOCK = {
    "sql": "SELECT name, country FROM singer WHERE age > 20;",
    "rationale": "Example: select singers older than 20.",
    "result": {
        "rows": 5,
        "columns": ["name", "country"],
        "rows_data": [["Alice", "France"], ["Bob", "USA"]],
    },
    "traces": [
        {"stage": "detector", "summary": "ok", "duration_ms": 5},
        {"stage": "planner", "summary": "intent parsed", "duration_ms": 120},
        {"stage": "generator", "summary": "sql generated", "duration_ms": 420},
        {"stage": "verifier", "summary": "passed", "duration_ms": 10},
    ],
    "metrics": {"EM": 0.15, "SM": 0.70, "ExecAcc": 0.73, "avg_latency_ms": 8113},
}


def load_mock_from_summary():
    """Build a mock pipeline response from the newest benchmark summary.

    Looks for ``benchmarks/results_pro/*/summary.json``, picks the most
    recently modified one, and maps its fields onto the mock-response shape.
    Falls back to ``HARDCODED_MOCK`` when no summary exists or anything goes
    wrong while reading it (best-effort by design for a demo UI).

    Returns:
        dict: Same shape as a real pipeline response (sql / rationale /
        result / traces / metrics).
    """
    try:
        # max() finds the newest file in one O(n) pass; no need to sort.
        latest = max(
            Path("benchmarks/results_pro").glob("*/summary.json"),
            key=lambda p: p.stat().st_mtime,
            default=None,
        )
        if latest is not None:
            sj = json.loads(latest.read_text(encoding="utf-8"))
            fallback = HARDCODED_MOCK["metrics"]
            return {
                "sql": sj.get("example_sql", HARDCODED_MOCK["sql"]),
                "rationale": sj.get("note", HARDCODED_MOCK["rationale"]),
                "result": {"rows": sj.get("total_samples", 0), "columns": []},
                "traces": HARDCODED_MOCK["traces"],
                "metrics": {
                    "EM": sj.get("avg_em", fallback["EM"]),
                    "SM": sj.get("avg_sm", fallback["SM"]),
                    "ExecAcc": sj.get("avg_execacc", fallback["ExecAcc"]),
                    "avg_latency_ms": sj.get(
                        "avg_latency_ms", fallback["avg_latency_ms"]
                    ),
                },
            }
    except Exception:
        # Deliberate best-effort: any IO/parse problem means "no summary".
        pass
    return HARDCODED_MOCK


def call_pipeline_api_or_mock(query: str, db_id: str | None = None, timeout=10):
    """Run a natural-language query through the backend, or serve mock data.

    Honors ``USE_MOCK`` to skip the network entirely; any request/HTTP/JSON
    failure also degrades gracefully to the mock response instead of raising.
    """
    if USE_MOCK:
        return load_mock_from_summary()

    body = {"query": query}
    if db_id:
        body["db_id"] = db_id

    try:
        resp = requests.post(API_QUERY, json=body, timeout=timeout)
        resp.raise_for_status()
        return resp.json()
    except Exception as exc:
        print(f"[demo] API call failed ({exc}); using mock instead.")
        return load_mock_from_summary()


def upload_db(file_obj):
    """Validate a SQLite file selected in the UI and upload it to the backend.

    Args:
        file_obj: Gradio file object (temp-file path in ``.name``,
            optional ``.size``), or None when nothing was selected.

    Returns:
        tuple[str | None, str]: (db_id, status message). db_id is None when
        the upload was skipped or failed.
    """
    if file_obj is None:
        return None, "No DB uploaded. Default DB will be used."
    name = getattr(file_obj, "name", "db.sqlite")
    # Client-side validation: extension whitelist and a 20MB size cap.
    if not name.endswith((".db", ".sqlite")):
        return None, "Only .db or .sqlite files are allowed."
    size = getattr(file_obj, "size", None)
    if size and size > 20 * 1024 * 1024:
        return None, "File too large (>20MB) for this demo."

    # Gradio gives a temp file path as file_obj.name; the context manager
    # guarantees the handle is closed even if the request raises.
    with open(file_obj.name, "rb") as fh:
        files = {"file": (name, fh, "application/octet-stream")}
        r = requests.post(API_UPLOAD, files=files, timeout=120)

    if r.ok:
        data = r.json()
        return data.get("db_id"), f"Uploaded OK. db_id={data.get('db_id')}"
    # Surface whatever the backend returned, JSON or plain text.
    try:
        body = r.json()
    except ValueError:
        body = r.text
    return None, f"Upload failed ({r.status_code}): {body}"


def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool):
    """Unified query handler: tries backend or mock fallback."""
    if not user_query.strip():
        return "❌ Please enter a query.", "", "", {}, [], [], "", []

    data = call_pipeline_api_or_mock(user_query, db_id)
    sql = data.get("sql") or ""
    explanation = data.get("rationale") or ""
    result = data.get("result", {})
    trace_list = data.get("traces", [])

    metrics = data.get("metrics", {})
    badges_text = (
        f"EM={metrics.get('EM', '?')} | SM={metrics.get('SM', '?')} | "
        f"ExecAcc={metrics.get('ExecAcc', '?')} | latency={metrics.get('avg_latency_ms', '?')}ms"
    )

    timings_table = []
    if trace_list and all("duration_ms" in t for t in trace_list):
        timings_table = [[t["stage"], t["duration_ms"]] for t in trace_list]

    return badges_text, sql, explanation, result, trace_list, [], "", timings_table


# ---- UI definition (unchanged) ----
# Declarative Gradio layout: upload row, question row, then result tabs.
with gr.Blocks(title="NL2SQL Copilot") as demo:
    gr.Markdown("# NL2SQL Copilot\nUpload a SQLite DB (optional) or use default.")

    # Holds the db_id returned by the backend after an upload;
    # None means the default database is used.
    db_state = gr.State(value=None)

    with gr.Row():
        db_file = gr.File(
            label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"]
        )
        upload_btn = gr.Button("Upload DB")
    db_msg = gr.Markdown()
    upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])

    with gr.Row():
        q = gr.Textbox(label="Question", scale=4)
        debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
        run = gr.Button("Run")

    # Primary outputs: metric badges, generated SQL, and its explanation.
    badges = gr.Markdown()
    sql_out = gr.Code(label="Final SQL", language="sql")
    exp_out = gr.Textbox(label="Explanation", lines=3)

    with gr.Tab("Result"):
        res_out = gr.JSON()

    with gr.Tab("Trace"):
        trace = gr.JSON(label="Stage trace")

    # NOTE(review): query_to_sql currently feeds these two with empty
    # placeholders ([] and ""); the tab is a stub for future repair output.
    with gr.Tab("Repair"):
        repair_candidates = gr.JSON(label="Candidates")
        repair_diff = gr.Textbox(label="Diff (if any)", lines=10)

    with gr.Tab("Timings"):
        timings = gr.Dataframe(headers=["metric", "ms"], datatype=["str", "number"])

    # Wire the Run button; output order must match query_to_sql's return tuple.
    run.click(
        query_to_sql,
        inputs=[q, db_state, debug],
        outputs=[
            badges,
            sql_out,
            exp_out,
            res_out,
            trace,
            repair_candidates,
            repair_diff,
            timings,
        ],
    )

if __name__ == "__main__":
    import os

    print("[demo] Launching Gradio demo on 0.0.0.0:7860 ...", flush=True)
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        share=False,
        debug=True,
    )