RealMati committed on
Commit
9442d19
·
verified ·
1 Parent(s): 4713276

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +280 -0
app.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import gradio as gr
4
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
+ import torch
6
+ import time
7
+
8
# Hugging Face Hub id of the fine-tuned seq2seq checkpoint served by this app.
MODEL_ID = "RealMati/text2sql-wikisql-v5"

# Load the tokenizer and model once at import time so every request reuses them.
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
model.eval()  # inference only: put the module in evaluation mode
print("Model loaded.")

# The stylesheet lives next to this file. Read it with an explicit encoding so
# the result does not depend on the platform's default locale.
css_path = os.path.join(os.path.dirname(__file__), "style.css")
with open(css_path, "r", encoding="utf-8") as f:
    CSS = f.read()
19
+
20
SQL_KEYWORDS = [
    "SELECT", "FROM", "WHERE", "AND", "OR", "NOT", "IN", "LIKE",
    "JOIN", "LEFT", "RIGHT", "INNER", "OUTER", "ON", "AS",
    "GROUP", "BY", "ORDER", "HAVING", "LIMIT", "OFFSET",
    "DISTINCT", "COUNT", "SUM", "AVG", "MIN", "MAX",
    "BETWEEN", "EXISTS", "UNION", "ALL", "ANY", "CASE",
    "WHEN", "THEN", "ELSE", "END", "IS", "NULL", "ASC", "DESC",
]

# Patterns are compiled once at import time rather than on every request.
_SPECIAL_TOKEN_RE = re.compile(r"<pad>|<unk>|<s>|</s>")
_WHITESPACE_RE = re.compile(r"\s+")
# Matches a single- or double-quoted segment. The capture group makes
# ``split`` keep the quoted segments, which land at the odd indices.
_QUOTED_RE = re.compile(r"""('[^']*'|"[^"]*")""")
_KEYWORD_PATTERNS = [
    (re.compile(rf"\b{re.escape(kw)}\b", re.IGNORECASE), kw)
    for kw in SQL_KEYWORDS
]


def postprocess_sql(sql):
    """Clean up raw model output and upper-case SQL keywords.

    Steps, in order: strip surrounding whitespace, remove leftover
    ``<pad>``/``<unk>``/``<s>``/``</s>`` markers, collapse whitespace runs to a
    single space, then upper-case every whole-word match of ``SQL_KEYWORDS``.

    Quoted segments (``'...'`` or ``"..."``) are left untouched by the keyword
    pass, so a literal such as ``'left wing'`` is not rewritten to
    ``'LEFT wing'``, which would change what the query compares against.

    Args:
        sql: Decoded model output.

    Returns:
        The normalised SQL string.
    """
    sql = _SPECIAL_TOKEN_RE.sub("", sql.strip())
    sql = _WHITESPACE_RE.sub(" ", sql)

    segments = _QUOTED_RE.split(sql)
    # Even indices are the text outside quotes; odd indices are quoted literals.
    # NOTE(review): unquoted multi-word column names (e.g. "Years in Toronto")
    # still get keyword-cased; fixing that would need the schema.
    for idx in range(0, len(segments), 2):
        chunk = segments[idx]
        for pattern, keyword in _KEYWORD_PATTERNS:
            chunk = pattern.sub(keyword, chunk)
        segments[idx] = chunk

    return "".join(segments).strip()
37
+
38
+
39
def predict(question, schema, num_beams, max_length):
    """Generate SQL for a natural-language question using the loaded model.

    The prompt is ``"translate to SQL: {question}"``, with
    ``" | schema: {schema}"`` appended when a non-blank schema is supplied.
    The prompt is truncated to 512 tokens before generation.

    Args:
        question: Natural-language question. ``None`` or blank input
            short-circuits with a hint message and no model call.
        schema: Optional schema description; ignored when ``None`` or blank.
        num_beams: Beam width for decoding (coerced with ``int``).
        max_length: ``max_length`` passed to ``model.generate`` (coerced
            with ``int``).

    Returns:
        A ``(sql, perf)`` tuple of strings: the post-processed SQL and a
        one-line summary with latency, beam count and input token count.
    """
    if not question or not question.strip():
        return (
            "-- Enter a question and schema, then click Generate SQL",
            "",
        )

    # Coerce once up front instead of at every use site.
    num_beams = int(num_beams)
    max_length = int(max_length)

    input_text = f"translate to SQL: {question}"
    if schema and schema.strip():
        input_text += f" | schema: {schema.strip()}"

    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    t0 = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            # early_stopping only applies to beam search; enabling it with a
            # single beam makes transformers flag the generation config.
            early_stopping=num_beams > 1,
            do_sample=False,
        )
    latency = time.perf_counter() - t0

    raw = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sql = postprocess_sql(raw)
    perf = f"Inference: {latency:.2f}s | Beams: {num_beams} | Tokens: {inputs['input_ids'].shape[1]}"

    return sql, perf
68
+
69
+
70
# Colour and geometry overrides applied on top of Gradio's Base theme.
# Keys are theme variable names and values are CSS values.
_THEME_OVERRIDES = {
    # Page
    "body_background_fill": "#0d1117",
    "body_text_color": "#e2e8f0",
    # Blocks / panels
    "block_background_fill": "#161b22",
    "block_border_color": "#1f2937",
    "block_border_width": "1px",
    "block_label_text_color": "#d1d5db",
    "block_title_text_color": "#f3f4f6",
    "block_radius": "12px",
    "block_shadow": "none",
    # Inputs
    "input_background_fill": "#111827",
    "input_border_color": "#1f2937",
    "input_border_width": "1px",
    "input_placeholder_color": "#4b5563",
    "input_radius": "8px",
    "slider_color": "#3b82f6",
    # Buttons
    "button_primary_background_fill": "linear-gradient(135deg, #3b82f6, #8b5cf6)",
    "button_primary_text_color": "#ffffff",
    "button_secondary_background_fill": "#111827",
    "button_secondary_text_color": "#d1d5db",
    "button_secondary_border_color": "#1f2937",
    # Miscellaneous
    "border_color_primary": "#1f2937",
    "color_accent_soft": "#111827",
}

# Base theme (hues and fonts) with the overrides above layered on.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="purple",
    neutral_hue="gray",
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("Fira Code"),
).set(**_THEME_OVERRIDES)
100
+
101
# Gradio UI: a header, three tabs (Demo / How It Works / Model & Training) and
# a footer.
# NOTE(review): the nesting of the `with` blocks and the indentation inside the
# HTML literals were reconstructed from a flattened diff view; confirm against
# the real app.py before applying.
with gr.Blocks(title="Text-to-SQL V5 | Direct SQL Output") as demo:

    # Page header: title, technology badges and the pipeline strip.
    gr.HTML("""
    <div class="app-header">
        <h1><span>Text-to-SQL</span> <small>V5</small></h1>
    </div>
    <div class="tech-badges">
        <span class="badge badge-indigo">T5-base (220M)</span>
        <span class="badge badge-purple">Seq2Seq</span>
        <span class="badge badge-emerald">WikiSQL 80K+</span>
        <span class="badge badge-amber">Direct SQL Output</span>
    </div>
    <div class="pipeline-strip">
        <span class="step step-input">Question</span>
        <span class="arrow">&rarr;</span>
        <span class="step step-model">T5 Encoder-Decoder</span>
        <span class="arrow">&rarr;</span>
        <span class="step step-sql">SQL Query</span>
    </div>
    """)

    with gr.Tabs():

        # Interactive tab: inputs in the first column, outputs in the second.
        with gr.Tab("Demo"):
            with gr.Row(equal_height=False):
                with gr.Column(scale=1):
                    question = gr.Textbox(
                        label="Natural Language Question",
                        placeholder="e.g. What is terrence ross' nationality?",
                        lines=2,
                    )
                    schema = gr.Textbox(
                        label="Database Schema (optional)",
                        placeholder="table_name: col1, col2, col3, ...",
                        lines=2,
                    )
                    gr.HTML('<p class="input-hint">Format: <code>table: col1, col2, col3</code></p>')
                    # Decoding controls forwarded to predict() as num_beams / max_length.
                    with gr.Row():
                        beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Beam Size")
                        max_len = gr.Slider(minimum=64, maximum=512, value=256, step=64, label="Max Length")
                    btn = gr.Button("Generate SQL", variant="primary", elem_classes=["generate-btn"], size="lg")

                with gr.Column(scale=1):
                    sql_out = gr.Textbox(
                        label="Generated SQL",
                        value="-- Enter a question, then click Generate SQL",
                        lines=4,
                        elem_classes=["sql-output"],
                    )
                    latency_out = gr.Textbox(label="Performance", value="", lines=1, elem_classes=["decode-box"])

            # Both the button click and pressing Enter in the question box run predict().
            btn.click(fn=predict, inputs=[question, schema, beams, max_len], outputs=[sql_out, latency_out])
            question.submit(fn=predict, inputs=[question, schema, beams, max_len], outputs=[sql_out, latency_out])

            # Each example row is [question, schema, beam size, max length].
            gr.Markdown("#### Examples")
            gr.Examples(
                examples=[
                    ["What is terrence ross' nationality", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
                    ["how many schools or teams had jalen rose", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
                    ["What was the date of the race in Misano?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
                    ["What was the number of race that Kevin Curtain won?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
                    ["Where was Assen held?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
                    ["How many different positions did Sherbrooke Faucons (qmjhl) provide in the draft?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
                    ["What are the nationalities of the player picked from Thunder Bay Flyers (ushl)", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
                    ["How many different nationalities do the players of New Jersey Devils come from?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
                    ["What's Dorain Anneck's pick number?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
                ],
                inputs=[question, schema, beams, max_len],
                outputs=[sql_out, latency_out],
                fn=predict,
                cache_examples=False,
            )

        # Static explanation of the prompt/output format and the V5 vs V6 comparison.
        with gr.Tab("How It Works"):
            gr.HTML("""
            <div class="arch-card">
                <h3>Architecture</h3>
                <p>A <strong>T5-base</strong> encoder-decoder fine-tuned on WikiSQL.
                This version generates <strong>SQL directly</strong> as free-form text,
                unlike V6 which outputs structured tokens. The model learns to produce
                syntactically correct SQL from natural language questions.</p>
            </div>
            <div class="arch-grid">
                <div class="arch-card">
                    <h3>Input Format</h3>
                    <p>Question and optional schema concatenated:</p>
                    <p><code>translate to SQL: {question} | schema: {table}: {col1}, {col2}</code></p>
                </div>
                <div class="arch-card">
                    <h3>Output Format</h3>
                    <p>The model directly outputs SQL text:</p>
                    <p><code>SELECT Nationality FROM players WHERE Player = 'Terrence Ross'</code></p>
                    <p>Post-processing normalizes whitespace and uppercases SQL keywords.</p>
                </div>
            </div>
            <div class="arch-card">
                <h3>V5 vs V6 Comparison</h3>
                <table class="encoding-table">
                    <tr><th>Aspect</th><th>V5 (Direct SQL)</th><th>V6 (Structured)</th></tr>
                    <tr><td>Output</td><td>Raw SQL string</td><td>SEL/AGG/CONDS tokens</td></tr>
                    <tr><td>Schema dependency</td><td>Optional</td><td>Required (for index mapping)</td></tr>
                    <tr><td>Flexibility</td><td>Can produce any SQL</td><td>Limited to WikiSQL operations</td></tr>
                    <tr><td>Reliability</td><td>May produce invalid SQL</td><td>Guaranteed valid structure</td></tr>
                    <tr><td>Generalization</td><td>Memorizes column names</td><td>Schema-agnostic indices</td></tr>
                </table>
            </div>
            """)

        # Static model card: headline stats, model/training details, dataset and limitations.
        with gr.Tab("Model & Training"):
            gr.HTML("""
            <div class="stats-grid">
                <div class="stat-card">
                    <div class="stat-value">220M</div>
                    <div class="stat-label">Parameters</div>
                </div>
                <div class="stat-card">
                    <div class="stat-value">80K+</div>
                    <div class="stat-label">Training Examples</div>
                </div>
                <div class="stat-card">
                    <div class="stat-value">T5-base</div>
                    <div class="stat-label">Architecture</div>
                </div>
                <div class="stat-card">
                    <div class="stat-value">WikiSQL</div>
                    <div class="stat-label">Dataset</div>
                </div>
            </div>
            <div class="arch-grid">
                <div class="arch-card">
                    <h3>Model</h3>
                    <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                        <li><strong>Base:</strong> T5-base (encoder-decoder)</li>
                        <li><strong>Tokenizer:</strong> SentencePiece (32K vocab)</li>
                        <li><strong>Max input:</strong> 512 tokens</li>
                        <li><strong>Max output:</strong> 256 tokens</li>
                        <li><strong>Decoding:</strong> Beam search (5 beams)</li>
                        <li><strong>Framework:</strong> Transformers + PyTorch</li>
                    </ul>
                </div>
                <div class="arch-card">
                    <h3>Training</h3>
                    <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                        <li><strong>Dataset:</strong> WikiSQL (Zhong et al., 2017)</li>
                        <li><strong>Train:</strong> ~56,355 examples</li>
                        <li><strong>Dev:</strong> ~8,421 examples</li>
                        <li><strong>Test:</strong> ~15,878 examples</li>
                        <li><strong>Output:</strong> Direct SQL strings</li>
                        <li><strong>Prefix:</strong> <code>translate to SQL:</code></li>
                    </ul>
                </div>
                <div class="arch-card">
                    <h3>WikiSQL Dataset</h3>
                    <p>80,654 hand-annotated SQL queries across 24,241 Wikipedia tables.
                    Single-table queries with SELECT, aggregation, and WHERE conditions.</p>
                    <p style="margin-top:0.4rem;"><a href="https://github.com/salesforce/WikiSQL" target="_blank">github.com/salesforce/WikiSQL</a></p>
                </div>
                <div class="arch-card">
                    <h3>Limitations</h3>
                    <ul style="margin:0.4rem 0;padding-left:1.2rem;">
                        <li><strong>Single-table only</strong> — no JOINs or subqueries</li>
                        <li><strong>May hallucinate</strong> column names not in schema</li>
                        <li><strong>No syntax guarantee</strong> — free-form output can be invalid</li>
                        <li><strong>AND-only</strong> conditions</li>
                    </ul>
                </div>
            </div>
            """)

    # Footer links shown below the tabs.
    gr.HTML("""
    <div class="app-footer">
        <a href="https://huggingface.co/RealMati/text2sql-wikisql-v5" target="_blank">Model</a>
        &nbsp;&bull;&nbsp;
        <a href="https://github.com/salesforce/WikiSQL" target="_blank">WikiSQL</a>
        &nbsp;&bull;&nbsp;
        Built with Transformers &amp; Gradio
    </div>
    """)

# NOTE(review): theme/css are passed to launch() rather than gr.Blocks();
# presumably this targets a Gradio release that accepts them here. Confirm
# against the pinned Gradio version.
demo.launch(theme=theme, css=CSS)