artmiss committed on
Commit
b08521f
·
verified ·
1 Parent(s): 3a1a24d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +484 -0
app.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ from peft import PeftModel
5
+
6
# --- Model Loading ---

# Identifier of the base seq2seq checkpoint that the adapter is applied to.
BASE_MODEL = "google/flan-t5-large"
# Identifier of the PEFT adapter; the tokenizer is loaded from it as well.
ADAPTER_MODEL = "artmiss/flan-t5-large-spider-text2sql"

# Everything below runs once at import time, so all requests share the same
# tokenizer and model objects.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL)
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
# Wrap the base model with the adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
# Switch to evaluation mode before serving predictions.
model.eval()
print("Model loaded.")
17
+
18
+ # --- Inference ---
19
+
20
def predict(question: str, schema: str) -> str:
    """Generate a SQL query for ``question`` against ``schema``.

    Args:
        question: Natural-language question typed by the user.
        schema: Schema text in the ``table(col1, col2) | table2(col3)`` form.

    Returns:
        The decoded model output, or a warning message when either input is
        empty or whitespace-only (the model is not called in that case).
    """
    if question.strip() == "":
        return "⚠️ Please enter a question."
    if schema.strip() == "":
        return "⚠️ Please add at least one table to the schema."

    prompt = f"Translate English to SQL: {question} | Schemas: {schema}"
    # Prompts longer than 512 tokens are truncated by the tokenizer.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    generation_kwargs = {"max_length": 128, "num_beams": 4}
    with torch.inference_mode():
        generated = model.generate(**encoded, **generation_kwargs)

    # Only the first returned sequence is decoded.
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
37
+
38
+
39
+ # --- Schema Builder Logic ---
40
+
41
def build_schema(tables_state):
    """Serialise the tables mapping as ``table(col1, col2) | table2(col3)``.

    Table and column names are stripped of surrounding whitespace. Blank
    column names are dropped, and a table is omitted when its name is blank
    or when it has no non-blank columns.
    """
    rendered = []
    for raw_name, raw_columns in tables_state.items():
        name = raw_name.strip()
        if not name:
            continue
        columns = [column.strip() for column in raw_columns if column.strip()]
        if not columns:
            continue
        rendered.append(name + "(" + ", ".join(columns) + ")")
    return " | ".join(rendered)
50
+
51
+
52
def add_table(tables_state, new_table_name):
    """Register a new, initially column-less table in the schema state.

    Args:
        tables_state: Mapping of table name to list of column names; a new
            key is inserted in place on success.
        new_table_name: Raw text from the table-name textbox.

    Returns:
        ``(tables_state, textbox_update, schema_markdown, status_message)``.
        The textbox update clears the input on success and is a no-op when
        the name is blank or already present.
    """
    name = new_table_name.strip()

    if name == "":
        added = False
        message = "⚠️ Table name cannot be empty."
    elif name in tables_state:
        added = False
        message = f"⚠️ Table '{name}' already exists."
    else:
        tables_state[name] = []
        added = True
        message = f"✅ Table '{name}' added."

    textbox_update = gr.update(value="") if added else gr.update()
    return tables_state, textbox_update, format_schema_display(tables_state), message
60
+
61
+
62
def add_column(tables_state, selected_table, new_col_name):
    """Append a column to the selected table in the schema state.

    Args:
        tables_state: Mapping of table name to list of column names; the
            selected table's list is extended in place on success.
        selected_table: Table name currently chosen in the dropdown (may be
            empty or ``None``).
        new_col_name: Raw text from the column-name textbox.

    Returns:
        ``(tables_state, textbox_update, schema_markdown, status_message)``.
        The textbox update clears the input on success and is a no-op
        otherwise.
    """
    col = new_col_name.strip()
    # A stale selection (table no longer in the state) is handled like
    # "nothing selected", the same way remove_table does, instead of letting
    # the lookup below raise KeyError.
    if not selected_table or selected_table not in tables_state:
        return tables_state, gr.update(), format_schema_display(tables_state), "⚠️ Select a table first."
    if not col:
        return tables_state, gr.update(), format_schema_display(tables_state), "⚠️ Column name cannot be empty."
    columns = tables_state[selected_table]
    if col in columns:
        return tables_state, gr.update(), format_schema_display(tables_state), f"⚠️ Column '{col}' already exists in '{selected_table}'."
    columns.append(col)
    return tables_state, gr.update(value=""), format_schema_display(tables_state), f"✅ Column '{col}' added to '{selected_table}'."
72
+
73
+
74
def remove_table(tables_state, selected_table):
    """Delete the selected table from the schema state.

    Returns:
        ``(tables_state, dropdown_update, schema_markdown, status_message)``.
        After a removal the dropdown lists the remaining tables and selects
        the first one (or nothing when none remain). When nothing valid is
        selected the state is left untouched and the dropdown selection is
        cleared.
    """
    is_known = bool(selected_table) and selected_table in tables_state
    if not is_known:
        current = list(tables_state.keys())
        return (
            tables_state,
            gr.update(choices=current, value=None),
            format_schema_display(tables_state),
            "⚠️ Select a table to remove.",
        )

    del tables_state[selected_table]
    remaining = list(tables_state.keys())
    next_selection = remaining[0] if remaining else None
    return (
        tables_state,
        gr.update(choices=remaining, value=next_selection),
        format_schema_display(tables_state),
        f"🗑️ Table '{selected_table}' removed.",
    )
80
+
81
+
82
def update_table_dropdown(tables_state):
    """Build a dropdown update listing every table, with the first one selected.

    The selection is ``None`` when the state holds no tables.
    """
    names = list(tables_state.keys())
    selected = names[0] if tables_state else None
    return gr.update(choices=names, value=selected)
84
+
85
+
86
def format_schema_display(tables_state):
    """Render the tables mapping as Markdown, one paragraph per table.

    Each table becomes ``**name** ( col1, col2 )``; a table without columns
    shows ``_no columns_`` instead. An empty state yields a placeholder line.
    """
    if tables_state:
        return "\n\n".join(
            "**{}** ( {} )".format(table, ", ".join(cols) if cols else "_no columns_")
            for table, cols in tables_state.items()
        )
    return "_No tables added yet._"
94
+
95
+
96
def run_prediction(question, tables_state):
    """Serialise the current schema state and return the SQL predicted for ``question``."""
    return predict(question, build_schema(tables_state))
100
+
101
+
102
def load_example(example, tables_state):
    """Turn an example entry into the UI values that load it.

    Args:
        example: Sequence whose first item is the question and whose second
            item is a schema string such as
            ``table(col1, col2) | table2(col3)``.
        tables_state: Unused; the example's schema replaces whatever was
            there before. Kept so existing callers keep working.

    Returns:
        ``(question, new_state, dropdown_update, schema_markdown)`` where
        ``new_state`` maps each parsed table name to its list of columns and
        the dropdown selects the first parsed table (or nothing).
    """
    question, schema_str = example[0], example[1]

    new_state = {}
    # Split on the bare "|" and strip each piece so the spacing around the
    # separator does not decide whether a table is recognised.
    for raw_part in schema_str.split("|"):
        part = raw_part.strip()
        table_name, paren, remainder = part.partition("(")
        if not paren or not part.endswith(")"):
            continue
        table_name = table_name.strip()
        if not table_name:
            # Unnamed tables are rejected everywhere else in the schema builder.
            continue
        # remainder still carries the closing ")"; drop it before splitting.
        new_state[table_name] = [c.strip() for c in remainder[:-1].split(",") if c.strip()]

    names = list(new_state)
    dropdown_update = gr.update(choices=names, value=names[0] if names else None)
    return question, new_state, dropdown_update, format_schema_display(new_state)
114
+
115
+
116
+ # --- Examples ---
117
+
118
# Each entry is [question, schema]. The schema uses the
# "table(col1, col2) | table2(col3)" text form that build_schema() produces
# and that load_example() parses back into a tables mapping.
EXAMPLES = [
    [
        "How many players are from each country?",
        "players(player_id, first_name, last_name, country_code, birth_date)",
    ],
    [
        "Who are the top 3 highest paid employees?",
        "employees(employee_id, name, age, salary, department_id)",
    ],
    [
        "What are the names of customers who placed an order?",
        "customers(customer_id, name, email, country) | orders(order_id, customer_id, total, date)",
    ],
    [
        "What is the average salary of employees in each department?",
        "employees(employee_id, name, salary, department_id) | departments(department_id, name, location)",
    ],
    [
        "Which products cost more than 100?",
        "products(product_id, name, price, category, stock)",
    ],
]
140
+
141
+ # --- UI ---
142
+
143
# Custom stylesheet handed to gr.Blocks(css=...). It defines the dark colour
# variables and styles the classes attached through elem_classes
# (primary-btn, secondary-btn, danger-btn, example-btn, schema-display,
# status-msg) as well as the classes used in the inline HTML
# (app-header, app-title, app-subtitle, panel-title).
# NOTE(review): `.panel` and `.sql-output` do not appear to be attached to any
# component in this file — confirm whether they are still needed.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=Syne:wght@400;600;800&display=swap');

:root {
    --bg: #0d0f14;
    --surface: #13161d;
    --surface2: #1a1e28;
    --border: #252a38;
    --accent: #4fffb0;
    --accent2: #4d9eff;
    --text: #e2e8f0;
    --muted: #64748b;
    --sql-bg: #0a0c10;
}

body, .gradio-container {
    background: var(--bg) !important;
    font-family: 'Syne', sans-serif !important;
    color: var(--text) !important;
}

/* Header */
.app-header {
    text-align: center;
    padding: 2.5rem 1rem 1.5rem;
    border-bottom: 1px solid var(--border);
    margin-bottom: 2rem;
}

.app-title {
    font-size: 2.4rem;
    font-weight: 800;
    letter-spacing: -0.03em;
    color: var(--accent);
    margin: 0;
    line-height: 1.1;
}

.app-subtitle {
    color: var(--muted);
    font-size: 0.95rem;
    margin-top: 0.5rem;
    font-family: 'JetBrains Mono', monospace;
}

/* Panels */
.panel {
    background: var(--surface);
    border: 1px solid var(--border);
    border-radius: 12px;
    padding: 1.25rem;
    margin-bottom: 1rem;
}

.panel-title {
    font-size: 0.7rem;
    font-weight: 600;
    letter-spacing: 0.12em;
    text-transform: uppercase;
    color: var(--muted);
    margin-bottom: 0.75rem;
}

/* Inputs */
input[type="text"], textarea {
    background: var(--surface2) !important;
    border: 1px solid var(--border) !important;
    border-radius: 8px !important;
    color: var(--text) !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.88rem !important;
}

input[type="text"]:focus, textarea:focus {
    border-color: var(--accent) !important;
    box-shadow: 0 0 0 2px rgba(79, 255, 176, 0.1) !important;
    outline: none !important;
}

/* Buttons */
button.primary-btn {
    background: var(--accent) !important;
    color: #0d0f14 !important;
    font-family: 'Syne', sans-serif !important;
    font-weight: 700 !important;
    font-size: 0.95rem !important;
    border-radius: 8px !important;
    border: none !important;
    padding: 0.65rem 1.5rem !important;
    cursor: pointer !important;
    transition: opacity 0.15s !important;
    letter-spacing: 0.02em !important;
}

button.primary-btn:hover {
    opacity: 0.85 !important;
}

button.secondary-btn {
    background: var(--surface2) !important;
    color: var(--text) !important;
    font-family: 'Syne', sans-serif !important;
    font-size: 0.85rem !important;
    border: 1px solid var(--border) !important;
    border-radius: 6px !important;
    padding: 0.5rem 1rem !important;
    cursor: pointer !important;
    transition: border-color 0.15s !important;
}

button.secondary-btn:hover {
    border-color: var(--accent2) !important;
}

button.danger-btn {
    background: transparent !important;
    color: #f87171 !important;
    font-family: 'Syne', sans-serif !important;
    font-size: 0.85rem !important;
    border: 1px solid #3d1f1f !important;
    border-radius: 6px !important;
    padding: 0.5rem 1rem !important;
    cursor: pointer !important;
}

/* SQL output */
.sql-output {
    background: var(--sql-bg) !important;
    border: 1px solid var(--border) !important;
    border-left: 3px solid var(--accent) !important;
    border-radius: 8px !important;
    padding: 1.1rem 1.25rem !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.92rem !important;
    color: var(--accent) !important;
    min-height: 60px;
    white-space: pre-wrap;
    word-break: break-word;
}

/* Schema display */
.schema-display {
    background: var(--sql-bg) !important;
    border: 1px solid var(--border) !important;
    border-radius: 8px !important;
    padding: 1rem !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.82rem !important;
    color: var(--accent2) !important;
    min-height: 80px;
}

/* Status messages */
.status-msg {
    font-family: 'JetBrains Mono', monospace;
    font-size: 0.8rem;
    color: var(--muted);
    min-height: 1.2rem;
    padding: 0.25rem 0;
}

/* Examples */
.example-btn {
    background: var(--surface2) !important;
    border: 1px solid var(--border) !important;
    border-radius: 6px !important;
    color: var(--muted) !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.78rem !important;
    padding: 0.4rem 0.75rem !important;
    cursor: pointer !important;
    transition: all 0.15s !important;
    text-align: left !important;
}

.example-btn:hover {
    border-color: var(--accent2) !important;
    color: var(--text) !important;
}

/* Dropdown */
select {
    background: var(--surface2) !important;
    border: 1px solid var(--border) !important;
    color: var(--text) !important;
    border-radius: 6px !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 0.85rem !important;
}

/* Labels */
label span {
    color: var(--muted) !important;
    font-size: 0.78rem !important;
    font-weight: 600 !important;
    letter-spacing: 0.08em !important;
    text-transform: uppercase !important;
    font-family: 'Syne', sans-serif !important;
}

/* Markdown */
.schema-display p, .schema-display strong {
    color: var(--accent2) !important;
    font-family: 'JetBrains Mono', monospace !important;
}

/* Hide gradio footer */
footer { display: none !important; }
"""
352
+
353
# Builds the whole interface: a schema builder on the left, the question box
# and generated SQL on the right, a row of example buttons, and the event
# wiring that connects the components to the handler functions above.
with gr.Blocks(css=CSS, title="Text-to-SQL") as demo:

    # Schema being edited: maps table name -> list of column names.
    tables_state = gr.State({})

    # Header
    gr.HTML("""
        <div class="app-header">
            <h1 class="app-title">Text → SQL</h1>
            <p class="app-subtitle">flan-t5-large · LoRA · Spider benchmark</p>
        </div>
    """)

    with gr.Row():

        # Left column — Schema Builder
        with gr.Column(scale=1):
            gr.HTML('<div class="panel-title">Schema Builder</div>')

            with gr.Group():
                new_table_input = gr.Textbox(
                    placeholder="e.g. players",
                    label="Table name",
                    lines=1,
                )
                add_table_btn = gr.Button("+ Add Table", elem_classes=["secondary-btn"])

            with gr.Group():
                table_dropdown = gr.Dropdown(
                    choices=[],
                    label="Select table",
                    interactive=True,
                )
                new_col_input = gr.Textbox(
                    placeholder="e.g. player_id",
                    label="Column name",
                    lines=1,
                )
                with gr.Row():
                    add_col_btn = gr.Button("+ Add Column", elem_classes=["secondary-btn"])
                    remove_table_btn = gr.Button("Remove Table", elem_classes=["danger-btn"])

            gr.HTML('<div class="panel-title" style="margin-top:1rem">Current Schema</div>')
            schema_display = gr.Markdown(
                value="_No tables added yet._",
                elem_classes=["schema-display"],
            )
            status_msg = gr.Markdown(value="", elem_classes=["status-msg"])

        # Right column — Question + Output
        with gr.Column(scale=1):
            gr.HTML('<div class="panel-title">Question</div>')
            question_input = gr.Textbox(
                placeholder="e.g. How many players are from each country?",
                label="Natural language question",
                lines=3,
            )
            generate_btn = gr.Button("Generate SQL →", elem_classes=["primary-btn"])

            gr.HTML('<div class="panel-title" style="margin-top:1.5rem">Generated SQL</div>')
            sql_output = gr.Code(
                label="",
                language="sql",
                lines=5,
                interactive=False,
            )

    # Examples
    # NOTE(review): this section is assumed to sit directly under the Blocks
    # root (below the two columns) — confirm against the intended layout.
    gr.HTML('<div class="panel-title" style="margin-top:1.5rem">Examples</div>')
    with gr.Row():
        for ex in EXAMPLES:
            # Button label is the question, cut to 45 characters plus an ellipsis.
            ex_btn = gr.Button(ex[0][:45] + ("…" if len(ex[0]) > 45 else ""), elem_classes=["example-btn"])
            ex_btn.click(
                # e=ex binds this iteration's example; without the default
                # every button would load the last entry of EXAMPLES.
                fn=lambda e=ex: load_example(e, {}),
                inputs=[],
                outputs=[question_input, tables_state, table_dropdown, schema_display],
            )

    # --- Event wiring ---

    # Adding a table updates the state first, then refreshes the dropdown
    # choices from the updated state.
    add_table_btn.click(
        fn=add_table,
        inputs=[tables_state, new_table_input],
        outputs=[tables_state, new_table_input, schema_display, status_msg],
    ).then(
        fn=update_table_dropdown,
        inputs=[tables_state],
        outputs=[table_dropdown],
    )

    add_col_btn.click(
        fn=add_column,
        inputs=[tables_state, table_dropdown, new_col_input],
        outputs=[tables_state, new_col_input, schema_display, status_msg],
    )

    remove_table_btn.click(
        fn=remove_table,
        inputs=[tables_state, table_dropdown],
        outputs=[tables_state, table_dropdown, schema_display, status_msg],
    )

    generate_btn.click(
        fn=run_prediction,
        inputs=[question_input, tables_state],
        outputs=[sql_output],
    )

    # Submitting a textbox triggers the same handler as its button.
    new_table_input.submit(
        fn=add_table,
        inputs=[tables_state, new_table_input],
        outputs=[tables_state, new_table_input, schema_display, status_msg],
    ).then(
        fn=update_table_dropdown,
        inputs=[tables_state],
        outputs=[table_dropdown],
    )

    new_col_input.submit(
        fn=add_column,
        inputs=[tables_state, table_dropdown, new_col_input],
        outputs=[tables_state, new_col_input, schema_display, status_msg],
    )

    question_input.submit(
        fn=run_prediction,
        inputs=[question_input, tables_state],
        outputs=[sql_output],
    )
481
+
482
+
483
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()