File size: 5,649 Bytes
b568b83
 
 
 
5cbfffe
b568b83
 
 
 
 
 
 
 
 
 
1c9c65e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
1c9c65e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b568b83
1c9c65e
b568b83
1c9c65e
 
 
 
 
 
 
5cbfffe
 
 
 
 
 
1c9c65e
b568b83
 
1c9c65e
b568b83
 
1c9c65e
 
b568b83
1c9c65e
b568b83
 
 
1c9c65e
b568b83
 
 
 
 
 
 
 
 
 
 
 
5cbfffe
 
 
b568b83
 
 
 
 
 
1c9c65e
 
b568b83
 
 
 
 
 
 
 
 
 
 
 
 
 
c4c85f7
b568b83
1c9c65e
b568b83
 
 
 
5cbfffe
 
 
 
 
 
 
 
 
 
b568b83
 
 
 
5cbfffe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os

import gradio as gr
import requests

API_UPLOAD = "http://localhost:8000/api/v1/nl2sql/upload_db"
API_QUERY = "http://localhost:8000/api/v1/nl2sql"


def upload_db(file_obj):
    """Upload a SQLite database file to the backend and return its id.

    Args:
        file_obj: Value delivered by the ``gr.File`` component. May be
            ``None`` (nothing selected), a filesystem path string, or a
            file-like wrapper whose ``name`` attribute is the temp-file path.

    Returns:
        A ``(db_id, message)`` pair. ``db_id`` is ``None`` whenever nothing was
        uploaded or the upload failed; ``message`` is a human-readable status
        suitable for display in the UI.
    """
    if file_obj is None:
        return None, "No DB uploaded. Default DB will be used."

    # Depending on the Gradio version/configuration the File component hands
    # over either a plain path string or a temp-file wrapper exposing the
    # path as ``.name``; accept both.
    path = file_obj if isinstance(file_obj, str) else getattr(file_obj, "name", None)
    if not isinstance(path, str) or not path:
        return None, "Could not determine the uploaded file path."

    # Case-insensitive so that e.g. "DATA.DB" is accepted as well.
    if not path.lower().endswith((".db", ".sqlite")):
        return None, "Only .db or .sqlite files are allowed."

    # Prefer an explicit ``size`` attribute when present, otherwise ask the
    # filesystem so the limit is enforced even for plain paths.
    size = getattr(file_obj, "size", None)
    if size is None:
        try:
            size = os.path.getsize(path)
        except OSError:
            size = None  # unreadable path is reported by open() below
    if size and size > 20 * 1024 * 1024:
        return None, "File too large (>20MB) for this demo."

    try:
        fh = open(path, "rb")
    except OSError as exc:
        return None, f"Could not read uploaded file: {exc}"

    with fh:
        # Send only the base name: the backend has no use for the local
        # temp-directory layout.
        files = {"file": (os.path.basename(path), fh, "application/octet-stream")}
        try:
            r = requests.post(API_UPLOAD, files=files, timeout=120)
        except requests.RequestException as exc:
            # Connection refused, timeout, etc. -- report instead of crashing.
            return None, f"Upload failed: {exc}"

    # Decode the body once; fall back to raw text when it is not JSON.
    try:
        body = r.json()
    except ValueError:
        body = r.text

    if not r.ok:
        # Show backend error body
        return None, f"Upload failed ({r.status_code}): {body}"

    db_id = body.get("db_id") if isinstance(body, dict) else None
    if not db_id:
        return None, f"Upload returned no db_id: {body}"
    return db_id, f"Uploaded OK. db_id={db_id}"


def _post_query(payload: dict):
    """POST ``payload`` as JSON to the query endpoint.

    Args:
        payload: JSON-serialisable request body.

    Returns:
        ``(True, parsed_json)`` for a successful response carrying a JSON
        body, otherwise ``(False, error_string)``. HTTP errors are formatted
        as ``"<status_code> <body>"``; transport failures (connection refused,
        timeout, ...) as ``"Request failed: <reason>"``.
    """
    try:
        r = requests.post(API_QUERY, json=payload, timeout=120)
    except requests.RequestException as exc:
        # Keep the (ok, data) contract even when the backend is unreachable.
        return False, f"Request failed: {exc}"

    try:
        body = r.json()
    except ValueError:
        if r.ok:
            return False, "Backend returned non-JSON body."
        body = r.text  # error responses may be plain text

    if r.ok:
        return True, body
    return False, f"{r.status_code} {body}"


def query_to_sql(user_query: str, db_id: str | None, _debug_flag: bool):
    # Build minimal schema-compliant payload.
    # Server expects request.query (name is 'query' per router code).
    base_payload = {"query": user_query.strip() if user_query else ""}

    # First try WITH db_id (if present). If backend rejects (422), retry WITHOUT.
    if db_id:
        ok, data = _post_query({**base_payload, "db_id": db_id})
        if not ok and isinstance(data, str) and data.startswith("422"):
            # Retry without db_id in case request model forbids extra fields.
            ok, data = _post_query(base_payload)
    else:
        ok, data = _post_query(base_payload)

    if not ok:
        # Surface backend error text to the UI
        err_badges = f"Error: {data}"
        return (
            err_badges,  # badges
            "",  # sql_out
            "",  # exp_out
            {},  # result (tab)
            [],  # trace (tab)
            [],  # repair_candidates (tab)
            "",  # repair_diff (tab)
            [],  # timings (tab)
        )

    d = data

    # Map fields to UI (server returns: ambiguous, sql, rationale, traces)
    sql = d.get("sql") or d.get("sql_final") or ""
    explanation = d.get("rationale") or d.get("explanation") or ""
    result = d.get("result", {})  # optional/maybe absent
    trace_list = d.get("traces") or d.get("trace") or []

    ambiguous_flag = "Yes" if d.get("ambiguous") else "No"
    safety = (
        "Allowed"
        if d.get("safety", {}).get("allowed", True)
        else f"Blocked: {d.get('safety', {}).get('blocked_reason')}"
    )
    verification = "Passed" if d.get("verification", {}).get("passed") else "Failed"
    repair = d.get("repair", {}) or {}
    repair_text = f"Applied: {repair.get('applied', False)}, Attempts: {repair.get('attempts', 0)}"

    timings = d.get("timings_ms", {}) or {}
    timings_table = [[k, timings[k]] for k in sorted(timings.keys())]

    badges_text = f"Ambiguous: {ambiguous_flag} | Safety: {safety} | Verification: {verification} | Repair: {repair_text}"

    return (
        badges_text,
        sql,
        explanation,
        result,
        trace_list,
        repair.get("candidates", []),
        repair.get("diff", ""),
        timings_table,
    )


# UI definition. Components are created in the order they appear below, and
# the two event handlers wire them to upload_db and query_to_sql.
with gr.Blocks(title="NL2SQL Copilot") as demo:
    gr.Markdown("# NL2SQL Copilot\nUpload a SQLite DB (optional) or use default.")

    # Holds the db_id returned by upload_db; stays None until an upload succeeds.
    db_state = gr.State(value=None)

    with gr.Row():
        db_file = gr.File(
            label="Upload SQLite (.db/.sqlite)", file_types=[".db", ".sqlite"]
        )
        upload_btn = gr.Button("Upload DB")
    # Status line for the upload (second value returned by upload_db).
    db_msg = gr.Markdown()
    upload_btn.click(upload_db, inputs=[db_file], outputs=[db_state, db_msg])

    with gr.Row():
        q = gr.Textbox(label="Question", scale=4)
        # The checkbox value is passed to query_to_sql, which accepts it but
        # does not include it in the request sent to the backend.
        debug = gr.Checkbox(label="Debug (UI only)", value=True, scale=1)
        run = gr.Button("Run")

    # Primary outputs: status summary line, generated SQL, and explanation.
    badges = gr.Markdown()
    sql_out = gr.Code(label="Final SQL", language="sql")
    exp_out = gr.Textbox(label="Explanation", lines=3)

    # Secondary outputs, one tab per category.
    with gr.Tab("Result"):
        res_out = gr.JSON()

    with gr.Tab("Trace"):
        trace = gr.JSON(label="Stage trace")

    with gr.Tab("Repair"):
        repair_candidates = gr.JSON(label="Candidates")
        repair_diff = gr.Textbox(label="Diff (if any)", lines=10)
    with gr.Tab("Timings"):
        # Rows are [metric, ms] pairs built by query_to_sql.
        timings = gr.Dataframe(headers=["metric", "ms"], datatype=["str", "number"])

    # The order of `outputs` must match the 8-tuple returned by query_to_sql.
    run.click(
        query_to_sql,
        inputs=[q, db_state, debug],
        outputs=[
            badges,
            sql_out,
            exp_out,
            res_out,
            trace,
            repair_candidates,
            repair_diff,
            timings,
        ],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    # Let Gradio pick a free port by default to avoid collisions
    demo.launch()