File size: 8,540 Bytes
b260a50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581fba6
b260a50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581fba6
b260a50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import gradio as gr
import torch
import re
import tempfile
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from database import init_database, get_schema, execute_query

# Model Setup
# NOTE(review): PROMPT_TEMPLATE below and the "SQLCoder Studio" UI use a
# SQLCoder-style text-to-SQL prompt, but this checkpoint is a TAPEX model.
# TAPEX models are (to my knowledge) BART encoder-decoders trained to *execute*
# SQL over tables, not causal LMs that generate SQL, so loading it through
# AutoModelForCausalLM is unlikely to yield usable queries — confirm the
# intended checkpoint.
MODEL_ID = "microsoft/tapex-large-sql-execution"
# Both are populated by load_model(); they remain None until it has been called.
tokenizer = None
sql_pipeline = None

def load_model():
    """Load MODEL_ID and build the global text-generation pipeline.

    Sets the module-level ``tokenizer`` and ``sql_pipeline`` globals.
    ``nl_to_sql_and_run`` calls ``sql_pipeline`` directly, so this must run
    before any question is handled.
    """
    global tokenizer, sql_pipeline
    # Log the configured checkpoint instead of a hard-coded copy of its name,
    # so the message cannot drift out of sync with MODEL_ID.
    print(f"Loading {MODEL_ID} ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Half precision when CUDA is available, full precision otherwise.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        # NOTE(review): this runs Python code shipped with the checkpoint;
        # keep it only for trusted model repositories.
        trust_remote_code=True,
    )
    sql_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding: the same prompt yields the same SQL
        return_full_text=False,  # return only the completion, not the prompt
        # Presumably set because the tokenizer may define no pad token — confirm.
        pad_token_id=tokenizer.eos_token_id,
    )
    print("Model loaded.")


# Text-to-SQL prompt filled by build_prompt(). {question} appears twice (task
# statement and answer preamble); {schema} receives get_schema()'s text. The
# trailing "[SQL]" tag cues the model to emit the query next; extract_sql()
# stops at a closing "[/SQL]" tag if the model produces one.
PROMPT_TEMPLATE = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Database Schema
The query will run on a database with the following schema:
{schema}

### Answer
Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
[SQL]
"""

def build_prompt(question: str, schema: str) -> str:
    """Return PROMPT_TEMPLATE with the user's question and the DB schema substituted in."""
    fields = {"schema": schema, "question": question}
    return PROMPT_TEMPLATE.format(**fields)


# Start of the first SQL statement: either a CTE header ("WITH [RECURSIVE] name
# [(cols)] AS (") or a plain SELECT. Requiring the "AS (" shape keeps ordinary
# prose such as "a query with joins" from being mistaken for a CTE.
_SQL_START_RE = re.compile(
    r"\bWITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\s*\(|\bSELECT\b",
    re.IGNORECASE,
)


def extract_sql(raw: str) -> str:
    """Pull the first SQL statement out of raw model output.

    Steps:
      1. Discard everything from the first ``[/SQL]`` end marker onward, so
         trailing commentary can never leak into the query.
      2. Remove Markdown code fences.
      3. Skip any leading prose up to the first ``WITH ... AS (`` or ``SELECT``.
         CTE queries are therefore kept whole instead of being cut down to an
         inner SELECT.
      4. Truncate after the first ``;``, which is kept.

    If no statement start is found, the cleaned text is returned as-is.
    """
    text = raw.split("[/SQL]", 1)[0]
    text = text.replace("```sql", "").replace("```", "")

    start = _SQL_START_RE.search(text)
    if start:
        text = text[start.start():]

    end = text.find(";")
    if end != -1:
        text = text[: end + 1]
    return text.strip()


def _append_turn(history: list, question: str, answer: str) -> list:
    """Return a new message list with one user/assistant exchange appended.

    *history* itself is not mutated.
    """
    return history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]


def _frame_to_markdown(df: pd.DataFrame) -> str:
    """Render *df* as a Markdown table for the chat transcript.

    ``DataFrame.to_markdown`` needs the optional ``tabulate`` package. If it is
    missing, fall back to a fixed-width text table in a code fence rather than
    failing after the query has already executed.
    """
    try:
        return df.to_markdown(index=False)
    except ImportError:
        return f"```\n{df.to_string(index=False)}\n```"


def _write_csv(df: pd.DataFrame) -> str:
    """Write *df* to a new temporary ``.csv`` file and return the file's path.

    The file is created with ``delete=False`` so it outlives this call for the
    download component; nothing here removes it. Writing through the open
    handle inside ``with`` guarantees the handle is closed even if ``to_csv``
    raises. UTF-8 matches ``to_csv(path)``'s default encoding.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", newline="", encoding="utf-8", delete=False
    ) as tmp:
        df.to_csv(tmp, index=False)
    return tmp.name


def nl_to_sql_and_run(question: str, history: list):
    """Translate *question* to SQL, execute it, and stream UI updates.

    Generator yielding 4-tuples
    ``(chat_history, status_markdown, table_update, download_update)``.
    Intermediate yields keep the previous *history* and report progress in the
    status slot. The final yield appends the question plus the answer (the
    generated SQL and either the results or an error message) and clears the
    status.

    Args:
        question: The user's natural-language question. Blank input is a no-op.
        history: Chat history as a list of ``{"role", "content"}`` dicts;
            it is never mutated.
    """
    if not question.strip():
        yield history, "", gr.update(visible=False), gr.update(visible=False)
        return

    prompt = build_prompt(question, get_schema())
    yield history, "Generating SQL query...", gr.update(visible=False), gr.update(visible=False)

    try:
        output = sql_pipeline(prompt)[0]["generated_text"]
        sql = extract_sql(output)
    except Exception as e:  # UI boundary: surface any model failure in the chat
        yield (
            _append_turn(history, question, f"Model error: {e}"),
            "",
            gr.update(visible=False),
            gr.update(visible=False),
        )
        return

    yield history, f"```sql\n{sql}\n```\n\nExecuting...", gr.update(visible=False), gr.update(visible=False)

    # NOTE(review): model output is executed as-is; nothing in this function
    # restricts it to read-only statements. Confirm execute_query guards
    # against writes if the database must not be modified.
    try:
        columns, rows = execute_query(sql)
    except Exception as e:  # UI boundary: report SQL errors instead of crashing
        answer = f"**Generated SQL:**\n```sql\n{sql}\n```\n\nExecution error: `{e}`"
        yield (
            _append_turn(history, question, answer),
            "",
            gr.update(visible=False),
            gr.update(visible=False),
        )
        return

    if rows:
        df = pd.DataFrame(rows, columns=columns)
        result_md = _frame_to_markdown(df)
        csv_path = _write_csv(df)
    else:
        df = pd.DataFrame()
        result_md = "*(query returned no rows)*"
        csv_path = None

    row_label = "rows" if len(rows) != 1 else "row"
    answer = f"**Generated SQL:**\n```sql\n{sql}\n```\n\n**Results ({len(rows)} {row_label}):**\n{result_md}"

    # Table and download button are only shown when there is data to display.
    yield (
        _append_turn(history, question, answer),
        "",
        gr.update(value=df, visible=bool(rows)),
        gr.update(value=csv_path, visible=bool(rows)),
    )


def view_schema():
    """Return the database schema from get_schema() inside a fenced ``sql`` Markdown block."""
    schema_text = get_schema()
    return "```sql\n{}\n```".format(schema_text)


# Dark-theme stylesheet passed to gr.Blocks(css=...) in create_app().
# .title-block and .badge style the header HTML built there. The @import pulls
# the Space Mono / DM Sans fonts from Google Fonts in the browser at page load.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500&display=swap');

body, .gradio-container {
    background: #0d0f14 !important;
    font-family: 'DM Sans', sans-serif;
    color: #e2e8f0;
}

.title-block {
    text-align: center;
    padding: 2rem 0 1rem;
}

.title-block h1 {
    font-size: 2rem;
    background: linear-gradient(135deg, #38bdf8, #818cf8);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-family: 'Space Mono', monospace;
    margin-bottom: 0.3rem;
}

.title-block p { color: #64748b; font-size: 0.95rem; }

.badge {
    display: inline-block;
    background: #1e2535;
    border: 1px solid #2d3748;
    border-radius: 20px;
    padding: 2px 12px;
    font-size: 0.75rem;
    color: #94a3b8;
    margin: 4px;
    font-family: 'Space Mono', monospace;
}
"""

# Sample questions rendered as one-click buttons in create_app(); clicking a
# button copies its text into the question box (it does not submit it).
# Presumably these match the tables created by init_database() — keep in sync.
EXAMPLE_QUERIES = [
    "Show me all employees in Engineering with salary above 120000",
    "Which department has the highest total salary budget?",
    "List all active projects with their budgets",
    "Who are the top 3 sales performers by total amount?",
    "How many employees are in each department?",
    "Show me all sales made in the East region in 2024",
]


def create_app():
    """Build and return the Gradio Blocks UI.

    Calls ``init_database()`` before building the layout, because the schema
    panel renders ``view_schema()`` at build time. This function does not load
    the model; ``load_model()`` must be called separately before questions can
    be answered.
    """
    init_database()

    with gr.Blocks(css=CSS, title="SQLCoder Studio") as demo:

        # Header; the model name is taken from MODEL_ID so it always matches
        # the checkpoint load_model() actually loads.
        gr.HTML(f"""
        <div class="title-block">
            <h1>SQLCoder Studio</h1>
            <p>Natural language to SQL to Results &nbsp;|&nbsp; Powered by {MODEL_ID}</p>
            <div style="margin-top:0.8rem">
                <span class="badge">employees</span>
                <span class="badge">departments</span>
                <span class="badge">projects</span>
                <span class="badge">sales</span>
            </div>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=460,
                    show_label=False,
                    render_markdown=True,
                    bubble_full_width=False,
                    type="messages",
                )

                with gr.Row():
                    question_input = gr.Textbox(
                        placeholder="Ask anything about the database...",
                        show_label=False,
                        scale=5,
                        lines=1,
                    )
                    submit_btn = gr.Button("RUN", variant="primary", scale=1)

                with gr.Row():
                    clear_btn = gr.Button("Clear chat", variant="secondary", size="sm")

                gr.HTML("<p style='color:#475569;font-size:0.78rem;margin-top:0.5rem'>Try an example:</p>")
                example_btns = []
                with gr.Row(wrap=True):
                    for eq in EXAMPLE_QUERIES:
                        b = gr.Button(eq, size="sm", variant="secondary")
                        example_btns.append(b)

            with gr.Column(scale=2):
                gr.HTML("<p style='color:#94a3b8;font-size:0.85rem;font-weight:500;margin-bottom:4px'>Result Table</p>")
                result_table = gr.Dataframe(
                    visible=False,
                    wrap=True,
                    height=220,
                )
                download_file = gr.File(
                    label="Download CSV",
                    visible=False,
                )
                gr.HTML("<p style='color:#94a3b8;font-size:0.85rem;font-weight:500;margin:1rem 0 4px'>Database Schema</p>")
                gr.Markdown(value=view_schema())

        # Progress messages ("Generating SQL query...", then the generated SQL
        # while it executes) are streamed here. The handler yields plain
        # strings, which update the value but not the visibility, so this
        # component must start visible; otherwise every status is swallowed.
        # An empty value renders nothing.
        status_md = gr.Markdown()
        history_state = gr.State([])

        def run(question, history):
            # Mirror each chat-history update into both the Chatbot and the State.
            gen = nl_to_sql_and_run(question, history)
            for h, status, table_update, dl_update in gen:
                yield h, h, status, table_update, dl_update

        outputs = [chatbot, history_state, status_md, result_table, download_file]

        # Clicking RUN and pressing Enter in the textbox trigger the same handler.
        for trigger in (submit_btn.click, question_input.submit):
            trigger(
                fn=run,
                inputs=[question_input, history_state],
                outputs=outputs,
            )
        clear_btn.click(
            fn=lambda: ([], [], "", gr.update(visible=False), gr.update(visible=False)),
            outputs=outputs,
        )

        # Bind each query via a default argument to avoid late-binding closures.
        for btn, eq in zip(example_btns, EXAMPLE_QUERIES):
            btn.click(fn=lambda q=eq: q, outputs=[question_input])

    return demo


if __name__ == "__main__":
    # The model must be loaded first: nl_to_sql_and_run calls the global
    # sql_pipeline that load_model() creates.
    load_model()
    create_app().launch()