File size: 18,926 Bytes
fe4cc52
735b5da
fe4cc52
 
 
f0e347c
fe4cc52
 
 
 
 
 
 
 
 
9572f14
f0e347c
9572f14
 
735b5da
 
 
06e0188
9572f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e0188
9572f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb06bc
 
 
 
9572f14
 
 
2bb06bc
9572f14
 
 
 
 
 
 
 
2bb06bc
9572f14
 
 
 
 
 
 
 
 
 
fe4cc52
 
06e0188
 
 
 
 
 
 
f0e347c
06e0188
 
 
 
 
 
 
f0e347c
06e0188
 
f0e347c
06e0188
 
 
f0e347c
 
3fb0154
f0e347c
 
 
 
9572f14
f0e347c
 
3fb0154
 
 
f0e347c
 
 
fe4cc52
 
 
 
 
f0e347c
fe4cc52
 
 
 
 
 
 
 
f0e347c
fe4cc52
9572f14
 
 
 
 
 
f0e347c
06e0188
f0e347c
3fb0154
 
06e0188
 
10b5b90
8eedef4
06e0188
8eedef4
06e0188
 
10b5b90
8eedef4
10b5b90
8eedef4
 
10b5b90
8eedef4
 
10b5b90
 
8eedef4
 
10b5b90
8eedef4
10b5b90
8eedef4
 
10b5b90
8eedef4
 
 
 
 
fe4cc52
 
3fb0154
f0e347c
3fb0154
06e0188
3fb0154
 
06e0188
f0e347c
3fb0154
f0e347c
3fb0154
f0e347c
06e0188
f0e347c
3fb0154
f0e347c
 
 
 
 
3fb0154
f0e347c
 
 
 
 
3fb0154
f0e347c
 
 
 
 
 
 
 
 
 
 
 
 
3fb0154
f0e347c
 
 
 
 
 
 
 
3fb0154
f0e347c
 
 
3fb0154
 
 
f0e347c
3fb0154
 
f0e347c
3fb0154
f0e347c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e0188
 
3fb0154
f0e347c
 
3fb0154
 
 
 
 
 
 
 
f0e347c
3fb0154
 
 
 
f0e347c
3fb0154
 
 
 
f0e347c
 
3fb0154
 
 
 
 
f0e347c
 
 
3fb0154
 
 
f0e347c
 
3fb0154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0e347c
 
3fb0154
f0e347c
 
3fb0154
 
 
 
f0e347c
3fb0154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0e347c
 
 
06e0188
 
f0e347c
3fb0154
f0e347c
3fb0154
f0e347c
3fb0154
06e0188
 
 
76545d6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
import re
import os
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import time

MODEL_ID = "RealMati/t2sql_v6_structured"

# Load the tokenizer and seq2seq model once at import time so every request
# reuses them; eval() puts the PyTorch module into inference mode (dropout off).
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
model.eval()
print("Model loaded.")

# Code tables for the model's structured output. The list position is the
# integer the model emits: AGG_OPS[i] is the SQL aggregate for "AGG:i" (the
# empty string means no aggregate), AGG_LABELS[i] is its display name, and
# OPS[i] is the comparison operator for condition op code i.
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
AGG_LABELS = ["None", "MAX", "MIN", "COUNT", "SUM", "AVG"]
OPS = ["=", ">", "<", ">=", "<=", "!="]

# Custom stylesheet shipped next to this script. Read it as UTF-8 explicitly:
# the platform default encoding (e.g. cp1252 on Windows) would fail on or
# garble any non-ASCII characters in the file.
css_path = os.path.join(os.path.dirname(__file__), "style.css")
with open(css_path, "r", encoding="utf-8") as f:
    CSS = f.read()


# Split only on a "|" that introduces the next field label, so the separator
# may have any surrounding whitespace (or none) and a "|" inside a condition
# value is left intact.
_FIELD_SPLIT_RE = re.compile(r"\s*\|\s*(?=(?:SEL|AGG|CONDS):)")


def _to_int(text):
    """Return ``int(text)``, or None when *text* is not an integer literal."""
    try:
        return int(text)
    except (TypeError, ValueError):
        return None


def decode_structured_output(text):
    """Parse the model's structured token string into ``(sel, agg, conds)``.

    Expected format::

        SEL:<col> | AGG:<agg> | CONDS:<col>,<op>,<value>;<col>,<op>,<value>...

    Each field is parsed independently. A malformed SEL or AGG leaves only that
    field as None, and a malformed condition is skipped without discarding the
    well-formed ones. Previously, one bad token aborted the whole parse.

    Args:
        text: Decoded model output. None or empty yields ``(None, None, [])``.

    Returns:
        ``sel`` and ``agg`` as ints (or None if absent/unparseable), and
        ``conds`` as a list of ``[col_index, op_index, value]`` lists where
        ``value`` is a whitespace-stripped string.
    """
    sel = agg = None
    conds = []
    if not text:
        return sel, agg, conds

    for part in _FIELD_SPLIT_RE.split(text.strip()):
        label, sep, payload = part.strip().partition(":")
        if not sep:
            continue
        label = label.strip()
        payload = payload.strip()
        if label == "SEL":
            sel = _to_int(payload)
        elif label == "AGG":
            agg = _to_int(payload)
        elif label == "CONDS" and payload:
            for raw_cond in payload.split(";"):
                # maxsplit=2 keeps commas inside the value itself.
                pieces = raw_cond.split(",", 2)
                if len(pieces) < 3:
                    continue
                col_idx = _to_int(pieces[0])
                op_idx = _to_int(pieces[1])
                if col_idx is None or op_idx is None:
                    continue
                conds.append([col_idx, op_idx, pieces[2].strip()])
    return sel, agg, conds


def parse_schema(schema_str):
    """Extract the table name and ordered column list from a schema string.

    Accepted forms (only the first ``|``-separated table is used)::

        "players: Player, No., Nationality"   -> ("players", [...])
        "Player, No., Nationality"            -> ("table", [...])

    Column order matters: the model refers to columns by 0-based position.

    Args:
        schema_str: Raw schema text from the UI. None or blank yields no columns.

    Returns:
        ``(table_name, columns)``. ``table_name`` falls back to ``"table"`` when
        no name is given, including the ``": col1, col2"`` case that previously
        produced an empty name and therefore an empty FROM clause.
    """
    fallback_name = "table"
    if not schema_str or not schema_str.strip():
        return fallback_name, []

    first_table = schema_str.strip().split("|")[0].strip()
    name, sep, cols_str = first_table.partition(":")
    if sep:
        table_name = name.strip() or fallback_name
    else:
        table_name, cols_str = fallback_name, first_table
    columns = [c.strip() for c in cols_str.split(",") if c.strip()]
    return table_name, columns


_PLAIN_IDENT_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def quote_col(name):
    """Return *name* as a SQL identifier, backtick-quoting it when required.

    Plain ASCII identifiers (letters, digits, underscore, not starting with a
    digit) are returned unchanged. Anything else is wrapped in backticks,
    e.g. names with spaces, punctuation such as ``No.`` or ``School/Club``, or
    a leading digit. Previously only names containing a space were quoted.
    Embedded backticks are escaped by doubling them.
    """
    if _PLAIN_IDENT_RE.fullmatch(name):
        return name
    return "`" + name.replace("`", "``") + "`"


# Unquoted numeric literal: optional sign, digits with optional fraction (or a
# bare fraction), optional exponent. Stricter than float(), which also accepts
# "nan", "inf" and "1_000"; none of these are SQL numbers.
_NUMERIC_LITERAL_RE = re.compile(r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")


def structured_to_sql(sel, agg, conds, columns, table_name="table"):
    """Assemble a single-table SELECT statement from decoded structured output.

    Args:
        sel: Index into ``columns`` of the selected column, or None.
        agg: Index into ``AGG_OPS`` (0 = no aggregate), or None.
        conds: Iterable of ``(col_index, op_index, value)`` triples, joined
            with AND in the WHERE clause.
        columns: Schema column names; position ``i`` is column index ``i``.
        table_name: Table name for the FROM clause.

    Returns:
        The SQL string, or None if ``sel`` or ``agg`` is None. Out-of-range
        column indices render as ``col<i>``, an out-of-range aggregate as no
        aggregate, and an out-of-range operator as ``=``.
    """
    if sel is None or agg is None:
        return None

    def column_ref(idx):
        # Require idx >= 0 as well: a negative index would otherwise wrap
        # around and silently pick a column from the end of the list.
        name = columns[idx] if 0 <= idx < len(columns) else f"col{idx}"
        return quote_col(name)

    def value_literal(raw):
        text = str(raw).strip()
        if _NUMERIC_LITERAL_RE.fullmatch(text):
            return text
        # Double embedded single quotes so values such as "O'Brien" stay one
        # well-formed string literal instead of terminating it early.
        return "'" + str(raw).replace("'", "''") + "'"

    agg_op = AGG_OPS[agg] if 0 <= agg < len(AGG_OPS) else ""
    target = column_ref(sel)
    select_expr = f"{agg_op}({target})" if agg_op else target
    sql = f"SELECT {select_expr} FROM {quote_col(table_name)}"

    where_parts = []
    for c_idx, c_op, c_val in conds or ():
        op_str = OPS[c_op] if 0 <= c_op < len(OPS) else "="
        where_parts.append(f"{column_ref(c_idx)} {op_str} {value_literal(c_val)}")
    if where_parts:
        sql += " WHERE " + " AND ".join(where_parts)
    return sql


def format_parsed(sel, agg, conds, columns):
    """Render decoded indices as a one-line, human-readable summary.

    Args:
        sel: Selected column index, or None.
        agg: Aggregation index into ``AGG_LABELS``, or None.
        conds: List of ``(col_index, op_index, value)`` triples; may be empty.
        columns: Schema column names used to resolve indices.

    Returns:
        Segments such as ``Column: ...``, ``Aggregation: ...`` and
        ``Conditions: ...`` joined by ``"  |  "``. Indices outside the valid
        range, including negative ones, are shown numerically rather than
        wrapping around to the end of the lookup lists.
    """
    parts = []
    if sel is not None:
        if 0 <= sel < len(columns):
            parts.append(f"Column: {columns[sel]} (index {sel})")
        else:
            parts.append(f"Column index: {sel}")
    if agg is not None:
        agg_label = AGG_LABELS[agg] if 0 <= agg < len(AGG_LABELS) else str(agg)
        parts.append(f"Aggregation: {agg_label}")
    if conds:
        cond_strs = []
        for c_idx, c_op, c_val in conds:
            c_name = columns[c_idx] if 0 <= c_idx < len(columns) else f"col{c_idx}"
            # Same "=" fallback as the SQL assembly, so both views agree.
            op_str = OPS[c_op] if 0 <= c_op < len(OPS) else "="
            cond_strs.append(f"{c_name} {op_str} {c_val}")
        parts.append(f"Conditions: {' AND '.join(cond_strs)}")
    else:
        parts.append("Conditions: None")
    return "  |  ".join(parts)


def predict(question, schema, num_beams, max_length):
    """Run one text-to-SQL request for the Gradio UI.

    Args:
        question: Natural-language question from the UI.
        schema: Schema text (``table: col1, col2, ...``); None is treated as empty.
        num_beams: Beam count for generation (slider value, coerced to int).
        max_length: Maximum generated length (slider value, coerced to int).

    Returns:
        A 4-tuple ``(sql, raw_output, parsed, perf)`` matching the four output
        textboxes. For missing input it returns placeholder/help messages
        instead of running the model.
    """
    if not question or not question.strip():
        return (
            "-- Enter a question and schema, then click Generate SQL",
            "Waiting for input...",
            "No query submitted yet",
            "",
        )
    schema_text = (schema or "").strip()
    table_name, columns = parse_schema(schema_text)
    if not columns:
        return (
            "-- Please provide a schema\n-- Format: table_name: col1, col2, col3",
            "Schema required",
            "Cannot map indices without column names",
            "",
        )

    # Prompt layout: task prefix, question, then the raw schema string. The
    # schema is known to be non-empty here because columns were parsed from it.
    input_text = f"translate to SQL: {question} | schema: {schema_text}"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    n_beams = int(num_beams)
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=int(max_length),
            num_beams=n_beams,
            # early_stopping only affects beam search; with a single beam it is
            # unused, and transformers may warn about it, so enable it only
            # when beam search is actually in effect.
            early_stopping=n_beams > 1,
            do_sample=False,
        )
    latency = time.perf_counter() - start

    raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sel, agg, conds = decode_structured_output(raw_output)

    if sel is not None and agg is not None:
        sql = structured_to_sql(sel, agg, conds, columns, table_name)
    else:
        sql = f"-- Could not parse model output\n-- Raw: {raw_output}"

    parsed = format_parsed(sel, agg, conds, columns) if sel is not None else "Parse failed"
    # "Tokens" reports the (possibly truncated) input length fed to the encoder.
    perf = f"Inference: {latency:.2f}s  |  Beams: {n_beams}  |  Tokens: {inputs['input_ids'].shape[1]}"
    return sql, raw_output, parsed, perf


# Dark colour palette layered on top of Gradio's Base theme. Keys are Gradio
# theme variable names and are forwarded unchanged to Theme.set().
_THEME_OVERRIDES = {
    "body_background_fill": "#0d1117",
    "body_text_color": "#e2e8f0",
    "block_background_fill": "#161b22",
    "block_border_color": "#1f2937",
    "block_border_width": "1px",
    "block_label_text_color": "#d1d5db",
    "block_title_text_color": "#f3f4f6",
    "block_radius": "12px",
    "block_shadow": "none",
    "input_background_fill": "#111827",
    "input_border_color": "#1f2937",
    "input_border_width": "1px",
    "input_placeholder_color": "#4b5563",
    "input_radius": "8px",
    "slider_color": "#3b82f6",
    "button_primary_background_fill": "linear-gradient(135deg, #3b82f6, #8b5cf6)",
    "button_primary_text_color": "#ffffff",
    "button_secondary_background_fill": "#111827",
    "button_secondary_text_color": "#d1d5db",
    "button_secondary_border_color": "#1f2937",
    "border_color_primary": "#1f2937",
    "color_accent_soft": "#111827",
}

# Base theme: blue/purple/gray hues with Inter for text and Fira Code for code.
_base_theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="purple",
    neutral_hue="gray",
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("Fira Code"),
)
theme = _base_theme.set(**_THEME_OVERRIDES)

with gr.Blocks(title="Text-to-SQL | T5 on WikiSQL") as demo:

    # Compact header — one line title + badges + pipeline
    gr.HTML("""
    <div class="app-header">
        <h1><span>Text-to-SQL</span></h1>
    </div>
    <div class="tech-badges">
        <span class="badge badge-indigo">T5-base (220M)</span>
        <span class="badge badge-purple">Seq2Seq</span>
        <span class="badge badge-emerald">WikiSQL 80K+</span>
        <span class="badge badge-amber">Structured Output</span>
    </div>
    <div class="pipeline-strip">
        <span class="step step-input">Question</span>
        <span class="arrow">&rarr;</span>
        <span class="step step-model">T5 Encoder-Decoder</span>
        <span class="arrow">&rarr;</span>
        <span class="step step-struct">SEL | AGG | CONDS</span>
        <span class="arrow">&rarr;</span>
        <span class="step step-sql">SQL</span>
    </div>
    """)

    with gr.Tabs():

        # ══ TAB 1: INFERENCE (main focus) ══
        with gr.Tab("Demo"):
            with gr.Row(equal_height=False):
                with gr.Column(scale=1):
                    question = gr.Textbox(
                        label="Natural Language Question",
                        placeholder="e.g. What is terrence ross' nationality?",
                        lines=2,
                    )
                    schema = gr.Textbox(
                        label="Database Schema",
                        placeholder="table_name: col1, col2, col3, ...",
                        lines=2,
                    )
                    gr.HTML('<p class="input-hint">Format: <code>table: col1, col2, col3</code> — column order = index mapping</p>')
                    with gr.Row():
                        beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Beam Size")
                        max_len = gr.Slider(minimum=64, maximum=512, value=256, step=64, label="Max Length")
                    btn = gr.Button("Generate SQL", variant="primary", elem_classes=["generate-btn"], size="lg")

                with gr.Column(scale=1):
                    sql_out = gr.Textbox(
                        label="Generated SQL",
                        value="-- Enter a question and schema, then click Generate SQL",
                        lines=3,
                        elem_classes=["sql-output"],
                    )
                    raw_out = gr.Textbox(label="Raw Structured Tokens", value="Waiting for input...", lines=1, elem_classes=["decode-box"])
                    parsed_out = gr.Textbox(label="Decoded Mapping", value="No query submitted yet", lines=1, elem_classes=["decode-box"])
                    latency_out = gr.Textbox(label="Performance", value="", lines=1, elem_classes=["decode-box"])

            # The button, Enter in the question box, and the examples all run
            # predict() over the same components; define the wiring once.
            predict_inputs = [question, schema, beams, max_len]
            predict_outputs = [sql_out, raw_out, parsed_out, latency_out]
            btn.click(fn=predict, inputs=predict_inputs, outputs=predict_outputs)
            question.submit(fn=predict, inputs=predict_inputs, outputs=predict_outputs)

            # Example schemas shared by several example questions below.
            players_schema = "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team"
            races_schema = "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report"
            draft_schema = "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team"

            gr.Markdown("#### Examples")
            gr.Examples(
                examples=[
                    ["What is terrence ross' nationality", players_schema, 5, 256],
                    ["how many schools or teams had jalen rose", players_schema, 5, 256],
                    ["What was the date of the race in Misano?", races_schema, 5, 256],
                    ["What was the number of race that Kevin Curtain won?", races_schema, 5, 256],
                    ["Where was Assen held?", races_schema, 5, 256],
                    ["How many different positions did Sherbrooke Faucons (qmjhl) provide in the draft?", draft_schema, 5, 256],
                    ["What are the nationalities of the player picked from Thunder Bay Flyers (ushl)", draft_schema, 5, 256],
                    ["How many different nationalities do the players of New Jersey Devils come from?", draft_schema, 5, 256],
                    ["What's Dorain Anneck's pick number?", draft_schema, 5, 256],
                ],
                inputs=predict_inputs,
                outputs=predict_outputs,
                fn=predict,
                cache_examples=False,
            )

        # ══ TAB 2: HOW IT WORKS ══
        with gr.Tab("How It Works"):
            gr.HTML("""
            <div class="arch-card">
                <h3>Architecture</h3>
                <p>A <strong>T5-base</strong> encoder-decoder fine-tuned on WikiSQL.
                Instead of generating raw SQL, it outputs <strong>structured tokens</strong>
                — column indices and operator codes — which a deterministic decoder
                maps to actual SQL using the provided schema.</p>
            </div>
            <div class="arch-grid">
                <div class="arch-card">
                    <h3>1. Input Encoding</h3>
                    <p>Question + schema concatenated:</p>
                    <p><code>translate to SQL: {question} | schema: {table}: {col1}, {col2}</code></p>
                    <p>Column order matters — the model references columns by <strong>0-based index</strong>.</p>
                </div>
                <div class="arch-card">
                    <h3>2. T5 Generation</h3>
                    <p>The encoder processes input, decoder generates structured tokens via beam search.</p>
                    <p>Output: <code>SEL:{col} | AGG:{agg} | CONDS:{col},{op},{val}</code></p>
                </div>
                <div class="arch-card">
                    <h3>3. Structured Decoding</h3>
                    <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                        <li><strong>SEL</strong> — column index to SELECT</li>
                        <li><strong>AGG</strong> — aggregation (0=none, 3=COUNT, etc.)</li>
                        <li><strong>CONDS</strong> — WHERE conditions as <code>col,op,value</code> tuples</li>
                    </ul>
                </div>
                <div class="arch-card">
                    <h3>4. SQL Assembly</h3>
                    <p>Indices mapped back to column names from schema. Operators converted to SQL.
                    Result: a valid, executable query.</p>
                </div>
            </div>
            <div class="arch-card">
                <h3>Why Structured Output?</h3>
                <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                    <li><strong>Schema-agnostic</strong> — learns patterns, not column names</li>
                    <li><strong>Always valid SQL</strong> — deterministic decoder guarantees syntax</li>
                    <li><strong>Smaller search space</strong> — predicts indices, not full strings</li>
                    <li><strong>Interpretable</strong> — each component inspectable independently</li>
                </ul>
            </div>
            <div class="arch-card">
                <h3>Encoding Reference</h3>
                <table class="encoding-table">
                    <tr><th>Component</th><th>Index</th><th>Meaning</th></tr>
                    <tr><td><strong>AGG</strong></td><td class="mono">0</td><td>No aggregation</td></tr>
                    <tr><td></td><td class="mono">1</td><td>MAX</td></tr>
                    <tr><td></td><td class="mono">2</td><td>MIN</td></tr>
                    <tr><td></td><td class="mono">3</td><td>COUNT</td></tr>
                    <tr><td></td><td class="mono">4</td><td>SUM</td></tr>
                    <tr><td></td><td class="mono">5</td><td>AVG</td></tr>
                    <tr><td><strong>OP</strong></td><td class="mono">0</td><td>= (equals)</td></tr>
                    <tr><td></td><td class="mono">1</td><td>> (greater than)</td></tr>
                    <tr><td></td><td class="mono">2</td><td>< (less than)</td></tr>
                    <tr><td></td><td class="mono">3</td><td>>= (greater or equal)</td></tr>
                    <tr><td></td><td class="mono">4</td><td><= (less or equal)</td></tr>
                    <tr><td></td><td class="mono">5</td><td>!= (not equal)</td></tr>
                </table>
            </div>
            """)

        # ══ TAB 3: MODEL INFO ══
        with gr.Tab("Model & Training"):
            gr.HTML("""
            <div class="stats-grid">
                <div class="stat-card">
                    <div class="stat-value">220M</div>
                    <div class="stat-label">Parameters</div>
                </div>
                <div class="stat-card">
                    <div class="stat-value">80K+</div>
                    <div class="stat-label">Training Examples</div>
                </div>
                <div class="stat-card">
                    <div class="stat-value">T5-base</div>
                    <div class="stat-label">Architecture</div>
                </div>
                <div class="stat-card">
                    <div class="stat-value">WikiSQL</div>
                    <div class="stat-label">Dataset</div>
                </div>
            </div>
            <div class="arch-grid">
                <div class="arch-card">
                    <h3>Model</h3>
                    <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                        <li><strong>Base:</strong> T5-base (encoder-decoder)</li>
                        <li><strong>Tokenizer:</strong> SentencePiece (32K vocab)</li>
                        <li><strong>Max input:</strong> 512 tokens</li>
                        <li><strong>Max output:</strong> 256 tokens</li>
                        <li><strong>Decoding:</strong> Beam search (5 beams)</li>
                        <li><strong>Framework:</strong> Transformers + PyTorch</li>
                    </ul>
                </div>
                <div class="arch-card">
                    <h3>Training</h3>
                    <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                        <li><strong>Dataset:</strong> WikiSQL (Zhong et al., 2017)</li>
                        <li><strong>Train:</strong> ~56,355 examples</li>
                        <li><strong>Dev:</strong> ~8,421 examples</li>
                        <li><strong>Test:</strong> ~15,878 examples</li>
                        <li><strong>Output:</strong> Structured tokens (SEL/AGG/CONDS)</li>
                        <li><strong>Prefix:</strong> <code>translate to SQL:</code></li>
                    </ul>
                </div>
                <div class="arch-card">
                    <h3>WikiSQL Dataset</h3>
                    <p>80,654 hand-annotated SQL queries across 24,241 Wikipedia tables.
                    Single-table queries with SELECT, aggregation, and WHERE conditions.</p>
                    <p style="margin-top:0.4rem;"><a href="https://github.com/salesforce/WikiSQL" target="_blank">github.com/salesforce/WikiSQL</a></p>
                </div>
                <div class="arch-card">
                    <h3>Limitations</h3>
                    <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                        <li><strong>Single-table only</strong> — no JOINs or subqueries</li>
                        <li><strong>Fixed operators</strong> — =, >, <, >=, <=, !=</li>
                        <li><strong>No GROUP BY / ORDER BY</strong></li>
                        <li><strong>AND-only</strong> conditions</li>
                        <li><strong>Schema required</strong> as input</li>
                    </ul>
                </div>
            </div>
            """)

    gr.HTML("""
    <div class="app-footer">
        <a href="https://huggingface.co/RealMati/t2sql_v6_structured" target="_blank">Model</a>
        &nbsp;&bull;&nbsp;
        <a href="https://github.com/salesforce/WikiSQL" target="_blank">WikiSQL</a>
        &nbsp;&bull;&nbsp;
        Built with Transformers &amp; Gradio
    </div>
    """)

# Launch only when run as a script, so importing this module (e.g. from tests
# or tooling) builds `demo` without starting a blocking server.
if __name__ == "__main__":
    demo.launch(theme=theme, css=CSS)