RealMati committed on
Commit
06e0188
·
verified ·
1 Parent(s): e281b6d

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +232 -67
app.py CHANGED
@@ -14,9 +14,88 @@ print("Model loaded.")
14
  AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
15
  OPS = ["=", ">", "<", ">=", "<=", "!="]
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def decode_structured_output(text):
19
- """Parse model output like 'SEL:0 | AGG:0 | CONDS:3,1,18' into components."""
20
  sel = agg = None
21
  conds = []
22
  try:
@@ -30,7 +109,7 @@ def decode_structured_output(text):
30
  cond_str = part[6:].strip()
31
  if cond_str:
32
  for c in cond_str.split(";"):
33
- vals = c.split(",", 2) # split max 2 times so value can contain commas
34
  if len(vals) >= 3:
35
  conds.append([int(vals[0]), int(vals[1]), vals[2]])
36
  except Exception:
@@ -39,47 +118,34 @@ def decode_structured_output(text):
39
 
40
 
41
  def parse_schema(schema_str):
42
- """Parse schema string like 'users: id, name, age, email' into (table_name, [columns])."""
43
  schema_str = schema_str.strip()
44
  if not schema_str:
45
  return "table", []
46
-
47
- # Take only the first table for now (WikiSQL is single-table)
48
  first_table = schema_str.split("|")[0].strip()
49
-
50
  if ":" in first_table:
51
  table_name, cols_str = first_table.split(":", 1)
52
  table_name = table_name.strip()
53
  columns = [c.strip() for c in cols_str.split(",") if c.strip()]
54
  else:
55
- # No table name, just columns
56
  table_name = "table"
57
  columns = [c.strip() for c in first_table.split(",") if c.strip()]
58
-
59
  return table_name, columns
60
 
61
 
62
  def structured_to_sql(sel, agg, conds, columns, table_name="table"):
63
- """Convert structured indices to a SQL query string."""
64
  if sel is None or agg is None:
65
  return None
66
-
67
  col_name = columns[sel] if sel < len(columns) else f"col{sel}"
68
-
69
- # SELECT clause
70
  if agg == 0:
71
  sql = f"SELECT {col_name} FROM {table_name}"
72
  else:
73
  agg_op = AGG_OPS[agg] if agg < len(AGG_OPS) else ""
74
  sql = f"SELECT {agg_op}({col_name}) FROM {table_name}"
75
-
76
- # WHERE clause
77
  if conds:
78
  where_parts = []
79
  for c_idx, c_op, c_val in conds:
80
  c_name = columns[c_idx] if c_idx < len(columns) else f"col{c_idx}"
81
  op_str = OPS[c_op] if c_op < len(OPS) else "="
82
- # Quote string values
83
  try:
84
  float(c_val)
85
  val_sql = c_val
@@ -88,16 +154,35 @@ def structured_to_sql(sel, agg, conds, columns, table_name="table"):
88
  where_parts.append(f"{c_name} {op_str} {val_sql}")
89
  if where_parts:
90
  sql += " WHERE " + " AND ".join(where_parts)
91
-
92
  return sql
93
 
94
 
95
- def predict(question: str, schema: str, num_beams: int, max_length: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  if not question.strip():
97
- return "Please enter a question.", "", ""
98
 
99
  table_name, columns = parse_schema(schema)
100
-
101
  input_text = f"translate to SQL: {question}"
102
  if schema.strip():
103
  input_text += f" | schema: {schema.strip()}"
@@ -114,59 +199,139 @@ def predict(question: str, schema: str, num_beams: int, max_length: int):
114
  )
115
 
116
  raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
117
-
118
- # Parse structured output
119
  sel, agg, conds = decode_structured_output(raw_output)
120
 
121
  if sel is not None and agg is not None and columns:
122
  sql = structured_to_sql(sel, agg, conds, columns, table_name)
123
  else:
124
- sql = f"(could not convert no schema columns provided)"
125
-
126
- return sql, raw_output, f"SEL={sel}, AGG={agg} ({AGG_OPS[agg] if agg and agg < len(AGG_OPS) else 'none'}), CONDS={conds}"
127
-
128
-
129
- demo = gr.Interface(
130
- fn=predict,
131
- inputs=[
132
- gr.Textbox(
133
- label="Natural Language Question",
134
- placeholder="e.g. Show all users older than 18",
135
- lines=2,
136
- ),
137
- gr.Textbox(
138
- label="Database Schema",
139
- placeholder="e.g. users: id, name, age, email",
140
- lines=2,
141
- ),
142
- gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Beam Size"),
143
- gr.Slider(minimum=64, maximum=512, value=256, step=64, label="Max Output Length"),
144
- ],
145
- outputs=[
146
- gr.Textbox(label="Generated SQL", lines=2),
147
- gr.Textbox(label="Raw Model Output", lines=1),
148
- gr.Textbox(label="Parsed Components", lines=1),
149
- ],
150
- title="Text-to-SQL (T5 Fine-tuned on WikiSQL)",
151
- description="Converts natural language questions to SQL. The model outputs structured tokens (SEL/AGG/CONDS) which are then converted to SQL using your schema.",
152
- examples=[
153
- # Table: 1-10015132-16 (Toronto Raptors players)
154
- ["What is terrence ross' nationality", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
155
- ["What clu was in toronto 1995-96", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
156
- ["which club was in toronto 2003-06", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
157
- ["how many schools or teams had jalen rose", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
158
- # Table: 1-10083598-1 (Racing)
159
- ["Where was Assen held?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
160
- ["What was the number of race that Kevin Curtain won?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
161
- ["What was the date of the race in Misano?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
162
- # Table: 1-1013129-2 (NHL Draft)
163
- ["How many different positions did Sherbrooke Faucons (qmjhl) provide in the draft?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
164
- ["What are the nationalities of the player picked from Thunder Bay Flyers (ushl)", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
165
- ["How many different college/junior/club teams provided a player to the Washington Capitals NHL Team?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
166
- # Table: 1-1013129-3 (NHL Draft)
167
- ["How many different nationalities do the players of New Jersey Devils come from?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
168
- ["What's Dorain Anneck's pick number?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
169
- ],
170
  )
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  demo.launch()
 
14
  AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
15
  OPS = ["=", ">", "<", ">=", "<=", "!="]
16
 
17
+ CSS = """
18
+ .main-header {
19
+ text-align: center;
20
+ margin-bottom: 0.5rem;
21
+ }
22
+ .main-header h1 {
23
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
24
+ -webkit-background-clip: text;
25
+ -webkit-text-fill-color: transparent;
26
+ font-size: 2.4rem;
27
+ font-weight: 800;
28
+ margin-bottom: 0.25rem;
29
+ }
30
+ .main-header p {
31
+ color: #6b7280;
32
+ font-size: 1.05rem;
33
+ max-width: 600px;
34
+ margin: 0 auto;
35
+ }
36
+ .pipeline-box {
37
+ background: linear-gradient(135deg, #f0f4ff 0%, #faf0ff 100%);
38
+ border: 1px solid #e0d4f5;
39
+ border-radius: 12px;
40
+ padding: 1rem 1.5rem;
41
+ text-align: center;
42
+ font-family: 'SF Mono', 'Fira Code', monospace;
43
+ font-size: 0.9rem;
44
+ color: #4a4a6a;
45
+ margin-bottom: 1rem;
46
+ }
47
+ .sql-output textarea {
48
+ font-family: 'SF Mono', 'Fira Code', 'Cascadia Code', monospace !important;
49
+ font-size: 1.1rem !important;
50
+ background: #1e1e2e !important;
51
+ color: #cdd6f4 !important;
52
+ border: 1px solid #45475a !important;
53
+ border-radius: 10px !important;
54
+ padding: 1rem !important;
55
+ }
56
+ .raw-output textarea {
57
+ font-family: 'SF Mono', 'Fira Code', monospace !important;
58
+ font-size: 0.9rem !important;
59
+ background: #f8f9fc !important;
60
+ color: #6b7280 !important;
61
+ border: 1px solid #e5e7eb !important;
62
+ border-radius: 8px !important;
63
+ }
64
+ .input-section {
65
+ border: 1px solid #e5e7eb;
66
+ border-radius: 12px;
67
+ padding: 1.25rem;
68
+ background: #fafbff;
69
+ }
70
+ .generate-btn {
71
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
72
+ border: none !important;
73
+ color: white !important;
74
+ font-weight: 600 !important;
75
+ font-size: 1.05rem !important;
76
+ padding: 0.75rem 2rem !important;
77
+ border-radius: 10px !important;
78
+ min-height: 46px !important;
79
+ }
80
+ .generate-btn:hover {
81
+ opacity: 0.92 !important;
82
+ transform: translateY(-1px);
83
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
84
+ }
85
+ .info-badge {
86
+ display: inline-block;
87
+ background: #eef2ff;
88
+ color: #4f46e5;
89
+ padding: 0.2rem 0.6rem;
90
+ border-radius: 6px;
91
+ font-size: 0.8rem;
92
+ font-weight: 600;
93
+ }
94
+ footer { display: none !important; }
95
+ """
96
+
97
 
98
  def decode_structured_output(text):
 
99
  sel = agg = None
100
  conds = []
101
  try:
 
109
  cond_str = part[6:].strip()
110
  if cond_str:
111
  for c in cond_str.split(";"):
112
+ vals = c.split(",", 2)
113
  if len(vals) >= 3:
114
  conds.append([int(vals[0]), int(vals[1]), vals[2]])
115
  except Exception:
 
118
 
119
 
120
  def parse_schema(schema_str):
 
121
  schema_str = schema_str.strip()
122
  if not schema_str:
123
  return "table", []
 
 
124
  first_table = schema_str.split("|")[0].strip()
 
125
  if ":" in first_table:
126
  table_name, cols_str = first_table.split(":", 1)
127
  table_name = table_name.strip()
128
  columns = [c.strip() for c in cols_str.split(",") if c.strip()]
129
  else:
 
130
  table_name = "table"
131
  columns = [c.strip() for c in first_table.split(",") if c.strip()]
 
132
  return table_name, columns
133
 
134
 
135
  def structured_to_sql(sel, agg, conds, columns, table_name="table"):
 
136
  if sel is None or agg is None:
137
  return None
 
138
  col_name = columns[sel] if sel < len(columns) else f"col{sel}"
 
 
139
  if agg == 0:
140
  sql = f"SELECT {col_name} FROM {table_name}"
141
  else:
142
  agg_op = AGG_OPS[agg] if agg < len(AGG_OPS) else ""
143
  sql = f"SELECT {agg_op}({col_name}) FROM {table_name}"
 
 
144
  if conds:
145
  where_parts = []
146
  for c_idx, c_op, c_val in conds:
147
  c_name = columns[c_idx] if c_idx < len(columns) else f"col{c_idx}"
148
  op_str = OPS[c_op] if c_op < len(OPS) else "="
 
149
  try:
150
  float(c_val)
151
  val_sql = c_val
 
154
  where_parts.append(f"{c_name} {op_str} {val_sql}")
155
  if where_parts:
156
  sql += " WHERE " + " AND ".join(where_parts)
 
157
  return sql
158
 
159
 
160
+ def format_parsed(sel, agg, conds, columns):
161
+ parts = []
162
+ if sel is not None and sel < len(columns):
163
+ parts.append(f"Column: {columns[sel]} (index {sel})")
164
+ elif sel is not None:
165
+ parts.append(f"Column index: {sel}")
166
+ if agg is not None:
167
+ agg_label = AGG_OPS[agg] if agg < len(AGG_OPS) and agg > 0 else "None"
168
+ parts.append(f"Aggregation: {agg_label}")
169
+ if conds:
170
+ cond_strs = []
171
+ for c_idx, c_op, c_val in conds:
172
+ c_name = columns[c_idx] if c_idx < len(columns) else f"col{c_idx}"
173
+ op_str = OPS[c_op] if c_op < len(OPS) else "="
174
+ cond_strs.append(f"{c_name} {op_str} {c_val}")
175
+ parts.append(f"Conditions: {', '.join(cond_strs)}")
176
+ else:
177
+ parts.append("Conditions: None")
178
+ return " | ".join(parts)
179
+
180
+
181
+ def predict(question, schema, num_beams, max_length):
182
  if not question.strip():
183
+ return "", "", ""
184
 
185
  table_name, columns = parse_schema(schema)
 
186
  input_text = f"translate to SQL: {question}"
187
  if schema.strip():
188
  input_text += f" | schema: {schema.strip()}"
 
199
  )
200
 
201
  raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
202
  sel, agg, conds = decode_structured_output(raw_output)
203
 
204
  if sel is not None and agg is not None and columns:
205
  sql = structured_to_sql(sel, agg, conds, columns, table_name)
206
  else:
207
+ sql = "(Provide a schema to convert structured output to SQL)"
208
+
209
+ parsed = format_parsed(sel, agg, conds, columns) if sel is not None else ""
210
+ return sql, raw_output, parsed
211
+
212
+
213
+ theme = gr.themes.Soft(
214
+ primary_hue="indigo",
215
+ secondary_hue="purple",
216
+ neutral_hue="slate",
217
+ font=gr.themes.GoogleFont("Inter"),
218
+ font_mono=gr.themes.GoogleFont("Fira Code"),
219
+ ).set(
220
+ body_background_fill="#fafbff",
221
+ block_background_fill="white",
222
+ block_border_width="1px",
223
+ block_border_color="#e5e7eb",
224
+ block_radius="12px",
225
+ block_shadow="0 1px 3px rgba(0,0,0,0.06)",
226
+ input_border_color="#d1d5db",
227
+ input_border_width="1px",
228
+ input_radius="8px",
229
+ button_primary_background_fill="linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
230
+ button_primary_text_color="white",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  )
232
 
233
+ with gr.Blocks(theme=theme, css=CSS, title="Text-to-SQL Demo") as demo:
234
+ # Header
235
+ gr.HTML("""
236
+ <div class="main-header">
237
+ <h1>Text-to-SQL</h1>
238
+ <p>Fine-tuned T5 model that converts natural language questions into structured SQL queries using the WikiSQL dataset</p>
239
+ </div>
240
+ """)
241
+
242
+ # Pipeline visualization
243
+ gr.HTML("""
244
+ <div class="pipeline-box">
245
+ Natural Language &rarr; T5 Encoder &rarr; Structured Tokens (SEL | AGG | CONDS) &rarr; SQL Query
246
+ </div>
247
+ """)
248
+
249
+ with gr.Row(equal_height=True):
250
+ # Left: Inputs
251
+ with gr.Column(scale=1):
252
+ gr.Markdown("### Input")
253
+ question = gr.Textbox(
254
+ label="Natural Language Question",
255
+ placeholder="e.g. What is terrence ross' nationality?",
256
+ lines=2,
257
+ elem_classes=["input-section"],
258
+ )
259
+ schema = gr.Textbox(
260
+ label="Database Schema",
261
+ placeholder="table_name: col1, col2, col3, ...",
262
+ lines=2,
263
+ info="Format: table_name: column1, column2, column3",
264
+ elem_classes=["input-section"],
265
+ )
266
+ with gr.Row():
267
+ beams = gr.Slider(
268
+ minimum=1, maximum=10, value=5, step=1,
269
+ label="Beam Size",
270
+ info="Higher = better quality, slower",
271
+ )
272
+ max_len = gr.Slider(
273
+ minimum=64, maximum=512, value=256, step=64,
274
+ label="Max Length",
275
+ )
276
+ btn = gr.Button("Generate SQL", variant="primary", elem_classes=["generate-btn"])
277
+
278
+ # Right: Outputs
279
+ with gr.Column(scale=1):
280
+ gr.Markdown("### Output")
281
+ sql_out = gr.Textbox(
282
+ label="Generated SQL",
283
+ lines=3,
284
+ elem_classes=["sql-output"],
285
+ show_copy_button=True,
286
+ )
287
+ raw_out = gr.Textbox(
288
+ label="Raw Model Output (Structured Tokens)",
289
+ lines=1,
290
+ elem_classes=["raw-output"],
291
+ )
292
+ parsed_out = gr.Textbox(
293
+ label="Decoded Components",
294
+ lines=1,
295
+ elem_classes=["raw-output"],
296
+ )
297
+
298
+ btn.click(
299
+ fn=predict,
300
+ inputs=[question, schema, beams, max_len],
301
+ outputs=[sql_out, raw_out, parsed_out],
302
+ )
303
+
304
+ # Examples
305
+ gr.Markdown("### Try These Examples")
306
+ gr.Examples(
307
+ examples=[
308
+ ["What is terrence ross' nationality", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
309
+ ["how many schools or teams had jalen rose", "players: Player, No., Nationality, Position, Years in Toronto, School/Club Team", 5, 256],
310
+ ["What was the date of the race in Misano?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
311
+ ["What was the number of race that Kevin Curtain won?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
312
+ ["Where was Assen held?", "races: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report", 5, 256],
313
+ ["How many different positions did Sherbrooke Faucons (qmjhl) provide in the draft?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
314
+ ["What are the nationalities of the player picked from Thunder Bay Flyers (ushl)", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
315
+ ["How many different nationalities do the players of New Jersey Devils come from?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
316
+ ["What's Dorain Anneck's pick number?", "draft: Pick, Player, Position, Nationality, NHL team, College/junior/club team", 5, 256],
317
+ ],
318
+ inputs=[question, schema, beams, max_len],
319
+ outputs=[sql_out, raw_out, parsed_out],
320
+ fn=predict,
321
+ cache_examples=False,
322
+ )
323
+
324
+ # Footer info
325
+ gr.HTML("""
326
+ <div style="text-align:center; margin-top:1.5rem; padding:1rem; color:#9ca3af; font-size:0.85rem;">
327
+ <span class="info-badge">T5-base</span>&nbsp;
328
+ <span class="info-badge">WikiSQL</span>&nbsp;
329
+ <span class="info-badge">Seq2Seq</span>&nbsp;
330
+ <span class="info-badge">Structured Output</span>
331
+ <p style="margin-top:0.75rem;">
332
+ Model: <a href="https://huggingface.co/RealMati/t2sql_v6_structured" target="_blank" style="color:#667eea;">RealMati/t2sql_v6_structured</a>
333
+ </p>
334
+ </div>
335
+ """)
336
+
337
  demo.launch()