File size: 8,707 Bytes
aac8545
cc371b0
 
 
 
 
b568b83
cc371b0
3c2f1c5
 
cc371b0
 
 
 
 
3c2f1c5
aac8545
cc371b0
 
 
 
 
 
 
b568b83
cc371b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
cc371b0
 
b568b83
 
 
cc371b0
b568b83
 
1c9c65e
 
cc371b0
1c9c65e
cc371b0
 
 
 
 
 
 
 
 
 
 
 
1c9c65e
 
cc371b0
1c9c65e
 
 
cc371b0
 
 
 
 
 
 
 
 
 
 
 
 
 
1c9c65e
cc371b0
 
 
 
 
 
 
1c9c65e
cc371b0
 
 
aac8545
cc371b0
 
 
 
 
 
 
 
 
aac8545
cc371b0
 
aac8545
cc371b0
aac8545
cc371b0
 
 
 
 
b568b83
cc371b0
 
 
 
 
 
1c9c65e
cc371b0
b568b83
 
a5a91fa
cc371b0
 
 
 
 
 
 
 
a5a91fa
cc371b0
 
 
 
 
a5a91fa
 
cc371b0
a5a91fa
 
cc371b0
 
a5a91fa
 
cc371b0
a5a91fa
cc371b0
 
 
 
 
a5a91fa
cc371b0
a5a91fa
cc371b0
 
 
 
 
 
 
 
 
 
a5a91fa
 
cc371b0
c9bbfcd
 
 
b695359
 
 
c9bbfcd
b695359
 
c9bbfcd
 
 
 
 
a5a91fa
 
cc371b0
a5a91fa
 
 
 
 
cc371b0
a5a91fa
 
c9bbfcd
 
 
 
 
 
 
 
 
 
 
cc371b0
 
c9bbfcd
 
 
 
 
a5a91fa
 
c9bbfcd
cc371b0
 
c9bbfcd
a5a91fa
 
 
 
cc371b0
a5a91fa
cc371b0
a5a91fa
b568b83
 
a5a91fa
b568b83
 
c8b0bcb
98694e9
 
 
c8b0bcb
 
98694e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import os
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import requests
from requests.exceptions import ConnectionError, RequestException, Timeout

# Backend configuration
API_HOST = os.getenv("API_HOST", "localhost")
API_PORT = os.getenv("API_PORT", "8000")
API_BASE = f"http://{API_HOST}:{API_PORT}"

API_QUERY = f"{API_BASE}/api/v1/nl2sql"
API_UPLOAD = f"{API_BASE}/api/v1/nl2sql/upload_db"
API_KEY = os.getenv("API_KEY", "dev-key")  # align with backend API_KEYS env


def call_pipeline_api(
    query: str,
    db_id: Optional[str] = None,
    timeout: int = 30,
) -> Dict[str, Any]:
    """
    Call the real FastAPI backend. No mock, no silent fallback.

    If db_id is None, the backend will use its default database.
    Any connection or HTTP error is surfaced back to the UI as an error payload.

    Args:
        query: Natural-language question to translate to SQL.
        db_id: Optional backend database identifier from a prior upload.
        timeout: Per-request timeout in seconds.

    Returns:
        The backend's JSON payload on success, or an error-shaped dict
        (empty ``sql``/``result``/``traces`` plus an ``error`` key) on failure.
    """
    payload: Dict[str, Any] = {"query": query}
    if db_id:
        payload["db_id"] = db_id

    headers: Dict[str, str] = {"Content-Type": "application/json"}
    if API_KEY:
        headers["X-API-Key"] = API_KEY

    def _error_payload(msg: str) -> Dict[str, Any]:
        # Uniform error shape consumed by query_to_sql(); logged for operators.
        print(f"[demo] {msg}", flush=True)
        return {
            "sql": "",
            "rationale": msg,
            "result": {},
            "traces": [],
            "error": msg,
        }

    try:
        resp = requests.post(API_QUERY, json=payload, headers=headers, timeout=timeout)
        resp.raise_for_status()
        return resp.json()
    except (ConnectionError, Timeout) as e:
        return _error_payload(f"Backend not reachable: {e}")
    except RequestException as e:
        # Bug fix: the original handler read the local `resp`, which is
        # unbound when the exception fires before assignment (e.g. an
        # invalid URL or redirect error raised inside requests.post),
        # turning the real error into a NameError. RequestException
        # carries the response (or None) on the exception itself.
        err_resp = getattr(e, "response", None)
        if err_resp is None:
            return _error_payload(f"Backend request failed: {e}")
        try:
            body: Any = err_resp.json()
        except Exception:
            body = err_resp.text
        return _error_payload(f"Backend error {err_resp.status_code}: {body}")


def upload_db(file_obj: Any) -> Tuple[Optional[str], str]:
    """
    Upload a SQLite database to the backend and return (db_id, message).

    The returned db_id is stored in Gradio state and used for subsequent queries.

    Args:
        file_obj: Gradio File value (has ``.name`` pointing at a temp file
            on disk and, depending on Gradio version, a ``.size``), or None.

    Returns:
        Tuple of (db_id or None, human-readable status message).
    """
    if file_obj is None:
        return None, "No DB uploaded. The backend default DB will be used."

    name = getattr(file_obj, "name", "db.sqlite")
    # str.endswith accepts a tuple of suffixes — one call instead of an `or` chain.
    if not name.endswith((".db", ".sqlite")):
        return None, "Only .db or .sqlite files are allowed."

    size = getattr(file_obj, "size", None)
    if size and size > 20 * 1024 * 1024:
        return None, "File too large (>20MB) for this demo."

    headers: Dict[str, str] = {}
    if API_KEY:
        headers["X-API-Key"] = API_KEY

    # Gradio's File component provides a temporary file on disk; `with`
    # guarantees the handle is closed on every path (the original used a
    # manual open/close with a try/finally).
    try:
        with open(file_obj.name, "rb") as f:
            files = {"file": (os.path.basename(name), f, "application/octet-stream")}
            resp = requests.post(API_UPLOAD, files=files, headers=headers, timeout=120)
    except OSError as e:
        return None, f"Could not open uploaded file: {e}"
    except RequestException as e:
        # Bug fix: the original let connection/timeout errors escape to
        # Gradio as a raw traceback; surface them as a status message,
        # consistent with call_pipeline_api's error handling.
        return None, f"Upload request failed: {e}"

    if not resp.ok:
        try:
            body: Any = resp.json()
        except Exception:
            body = resp.text
        return None, f"Upload failed ({resp.status_code}): {body}"

    try:
        data = resp.json()
    except Exception:
        return None, f"Upload succeeded but response was not JSON: {resp.text}"
    db_id = data.get("db_id")
    return db_id, f"Uploaded OK. db_id={db_id}"


def query_to_sql(
    user_query: str,
    db_id: Optional[str],
    _debug_flag: bool,
) -> Tuple[str, str, str, Any, List[Dict[str, Any]], List[List[Any]]]:
    """
    Run the full NL2SQL pipeline via the backend and format outputs for the UI.

    Args:
        user_query: Natural-language question typed by the user.
        db_id: Backend db identifier from a prior upload, or None for default.
        _debug_flag: UI-only checkbox value; currently unused by the pipeline.

    Returns:
        badges_text, sql, explanation, result_json, traces_json, timings_table
    """
    if not user_query.strip():
        msg = "❌ Please enter a query."
        return msg, "", msg, {}, [], []

    data = call_pipeline_api(user_query, db_id)

    # Explicit error propagation from backend: only treat it as fatal when
    # no SQL was produced at all.
    if data.get("error") and not data.get("sql"):
        err_msg = str(data.get("error"))
        return f"❌ {err_msg}", "", err_msg, {}, [], []

    sql = str(data.get("sql") or "")
    explanation = str(data.get("rationale") or "")
    result = data.get("result", {})
    traces = data.get("traces", []) or []

    # Hoisted: the original evaluated this all(...) completeness scan twice
    # (once for the badge, once for the timings table).
    have_timings = bool(traces) and all("duration_ms" in t for t in traces)

    badges_text = ""
    timings_table: List[List[Any]] = []
    if have_timings:
        # Simple latency badge from traces (sum of duration_ms).
        total_ms = sum(float(t.get("duration_ms", 0.0)) for t in traces)
        badges_text = f"latency≈{int(total_ms)}ms"
        # Timings table for the Timings tab: one [stage, duration_ms] row each.
        timings_table = [
            [t.get("stage", "?"), t.get("duration_ms", 0.0)] for t in traces
        ]

    return badges_text, sql, explanation, result, traces, timings_table


def build_ui() -> gr.Blocks:
    """
    Build the Gradio UI for the NL2SQL Copilot demo.

    - Optional DB upload (SQLite)
    - Textbox for the natural language question
    - Example queries aligned with the default Chinook DB
    - Tabs for result, trace, repair notes, and per-stage timings

    Returns:
        The assembled (but not yet launched) gr.Blocks application.

    NOTE: component creation order inside the `with` contexts determines the
    rendered layout — do not reorder statements casually.
    """
    with gr.Blocks(title="NL2SQL Copilot") as demo:
        gr.Markdown(
            "# NL2SQL Copilot\n"
            "Upload a SQLite DB (optional) or use the backend default database."
        )

        # Holds the db_id returned by upload_db(); None means backend default DB.
        db_state = gr.State(value=None)

        # DB upload section
        with gr.Row():
            db_file = gr.File(
                label="Upload SQLite (.db/.sqlite)",
                file_types=[".db", ".sqlite"],
            )
            upload_btn = gr.Button("Upload DB")

        db_msg = gr.Markdown()
        upload_btn.click(
            upload_db,
            inputs=[db_file],
            outputs=[db_state, db_msg],
        )

        # Query input and run button
        with gr.Row():
            q = gr.Textbox(
                label="Question",
                placeholder="e.g. Top 3 albums by total sales",
                scale=4,
            )
            # Passed through to query_to_sql as _debug_flag; UI-only for now.
            debug = gr.Checkbox(
                label="Debug (UI only)",
                value=True,
                scale=1,
            )
            run = gr.Button("Run")

        # Example queries compatible with the Chinook schema
        gr.Examples(
            examples=[
                ["List all artists"],
                [
                    "List customers whose total spending is above the average invoice total."
                ],
                ["Total number of tracks per genre"],
                ["List all albums with their total sales"],
                ["Customers spending above average"],
            ],
            inputs=[q],
            label="Try these example queries",
        )

        # Output widgets, in the same order as query_to_sql's return tuple.
        badges = gr.Markdown()
        sql_out = gr.Code(label="Final SQL", language="sql")
        exp_out = gr.Textbox(label="Explanation", lines=4)

        with gr.Tab("Result"):
            res_out = gr.JSON()

        with gr.Tab("Trace"):
            trace_out = gr.JSON(label="Stage trace")

        with gr.Tab("Repair"):
            # Static placeholder content only — no components are wired here yet.
            gr.Markdown(
                """
                ### Repair & self-healing (pipeline-level)

                The repair loop is fully implemented in the backend:

                * If a candidate SQL fails safety or execution checks,
                  the pipeline attempts to **repair** it.
                * All repair attempts and outcomes are tracked in Prometheus
                  (for example, `nl2sql_repair_attempts_total` and related rates).

                For now, detailed before/after SQL diffs and repair candidates
                are exposed via traces and metrics dashboards.

                This tab is reserved for a future, richer UI:
                side-by-side SQL diff, repair candidates, and explanations.
                """
            )

        with gr.Tab("Timings"):
            timings = gr.Dataframe(
                headers=["stage", "duration_ms"],
                datatype=["str", "number"],
            )

        # Wire the Run button: outputs must match query_to_sql's tuple order.
        run.click(
            query_to_sql,
            inputs=[q, db_state, debug],
            outputs=[badges, sql_out, exp_out, res_out, trace_out, timings],
        )

    return demo


# Built at import time so deployment tooling (e.g. `gradio deploy` / Spaces)
# can discover the app under the conventional module-level name `demo`.
demo = build_ui()

if __name__ == "__main__":
    print("[demo] Launching Gradio demo on 0.0.0.0:7860 ...", flush=True)
    # Host/port are overridable via env; share=False keeps the app local-only.
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        share=False,
        debug=True,
    )