import re import os import gradio as gr from transformers import AutoModelForSeq2SeqLM, AutoTokenizer import torch import time MODEL_ID = "RealMati/t2sql_v6_structured" print(f"Loading model: {MODEL_ID}") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID) model.eval() print("Model loaded.") AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"] AGG_LABELS = ["None", "MAX", "MIN", "COUNT", "SUM", "AVG"] OPS = ["=", ">", "<", ">=", "<=", "!="] css_path = os.path.join(os.path.dirname(__file__), "style.css") with open(css_path, "r") as f: CSS = f.read() def decode_structured_output(text): sel = agg = None conds = [] try: for part in text.strip().split(" | "): part = part.strip() if part.startswith("SEL:"): sel = int(part[4:].strip()) elif part.startswith("AGG:"): agg = int(part[4:].strip()) elif part.startswith("CONDS:"): cond_str = part[6:].strip() if cond_str: for c in cond_str.split(";"): vals = c.split(",", 2) if len(vals) >= 3: conds.append([int(vals[0]), int(vals[1]), vals[2]]) except Exception: pass return sel, agg, conds def parse_schema(schema_str): schema_str = schema_str.strip() if not schema_str: return "table", [] first_table = schema_str.split("|")[0].strip() if ":" in first_table: table_name, cols_str = first_table.split(":", 1) table_name = table_name.strip() columns = [c.strip() for c in cols_str.split(",") if c.strip()] else: table_name = "table" columns = [c.strip() for c in first_table.split(",") if c.strip()] return table_name, columns def quote_col(name): return f"`{name}`" if " " in name else name def structured_to_sql(sel, agg, conds, columns, table_name="table"): if sel is None or agg is None: return None col_name = quote_col(columns[sel] if sel < len(columns) else f"col{sel}") if agg == 0: sql = f"SELECT {col_name} FROM {table_name}" else: agg_op = AGG_OPS[agg] if agg < len(AGG_OPS) else "" sql = f"SELECT {agg_op}({col_name}) FROM {table_name}" if conds: where_parts = [] for c_idx, c_op, 
c_val in conds: c_name = quote_col(columns[c_idx] if c_idx < len(columns) else f"col{c_idx}") op_str = OPS[c_op] if c_op < len(OPS) else "=" try: float(c_val) val_sql = c_val except (ValueError, TypeError): val_sql = f"'{c_val}'" where_parts.append(f"{c_name} {op_str} {val_sql}") if where_parts: sql += " WHERE " + " AND ".join(where_parts) return sql def format_parsed(sel, agg, conds, columns): parts = [] if sel is not None and sel < len(columns): parts.append(f"Column: {columns[sel]} (index {sel})") elif sel is not None: parts.append(f"Column index: {sel}") if agg is not None: agg_label = AGG_LABELS[agg] if agg < len(AGG_LABELS) else str(agg) parts.append(f"Aggregation: {agg_label}") if conds: cond_strs = [] for c_idx, c_op, c_val in conds: c_name = columns[c_idx] if c_idx < len(columns) else f"col{c_idx}" op_str = OPS[c_op] if c_op < len(OPS) else "=" cond_strs.append(f"{c_name} {op_str} {c_val}") parts.append(f"Conditions: {' AND '.join(cond_strs)}") else: parts.append("Conditions: None") return " | ".join(parts) def predict(question, schema, num_beams, max_length): if not question or not question.strip(): return ( "-- Enter a question and schema, then click Generate SQL", "Waiting for input...", "No query submitted yet", "", ) table_name, columns = parse_schema(schema) if not columns: return ( "-- Please provide a schema\n-- Format: table_name: col1, col2, col3", "Schema required", "Cannot map indices without column names", "", ) input_text = f"translate to SQL: {question}" if schema.strip(): input_text += f" | schema: {schema.strip()}" inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True) t0 = time.time() with torch.no_grad(): outputs = model.generate( **inputs, max_length=int(max_length), num_beams=int(num_beams), early_stopping=True, do_sample=False, ) latency = time.time() - t0 raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True) sel, agg, conds = decode_structured_output(raw_output) if sel is not None and agg 
is not None and columns: sql = structured_to_sql(sel, agg, conds, columns, table_name) else: sql = f"-- Could not parse model output\n-- Raw: {raw_output}" parsed = format_parsed(sel, agg, conds, columns) if sel is not None else "Parse failed" perf = f"Inference: {latency:.2f}s | Beams: {int(num_beams)} | Tokens: {inputs['input_ids'].shape[1]}" return sql, raw_output, parsed, perf theme = gr.themes.Base( primary_hue="blue", secondary_hue="purple", neutral_hue="gray", font=gr.themes.GoogleFont("Inter"), font_mono=gr.themes.GoogleFont("Fira Code"), ).set( body_background_fill="#0d1117", body_text_color="#e2e8f0", block_background_fill="#161b22", block_border_color="#1f2937", block_border_width="1px", block_label_text_color="#d1d5db", block_title_text_color="#f3f4f6", block_radius="12px", block_shadow="none", input_background_fill="#111827", input_border_color="#1f2937", input_border_width="1px", input_placeholder_color="#4b5563", input_radius="8px", slider_color="#3b82f6", button_primary_background_fill="linear-gradient(135deg, #3b82f6, #8b5cf6)", button_primary_text_color="#ffffff", button_secondary_background_fill="#111827", button_secondary_text_color="#d1d5db", button_secondary_border_color="#1f2937", border_color_primary="#1f2937", color_accent_soft="#111827", ) with gr.Blocks(title="Text-to-SQL | T5 on WikiSQL") as demo: # Compact header — one line title + badges + pipeline gr.HTML("""
Format: table: col1, col2, col3 — column order = index mapping
A T5-base encoder-decoder fine-tuned on WikiSQL. Instead of generating raw SQL, it outputs structured tokens — column indices and operator codes — which a deterministic decoder maps to actual SQL using the provided schema.
Question + schema concatenated:
translate to SQL: {question} | schema: {table}: {col1}, {col2}
Column order matters — the model references columns by 0-based index.
The encoder processes the input; the decoder generates structured tokens via beam search.
Output: SEL:{col} | AGG:{agg} | CONDS:{col},{op},{val}
col,op,value tuples. Indices are mapped back to column names from the schema, and operator codes are converted to SQL operators. Result: a valid, executable query.
| Component | Index | Meaning |
|---|---|---|
| AGG | 0 | No aggregation |
| | 1 | MAX |
| | 2 | MIN |
| | 3 | COUNT |
| | 4 | SUM |
| | 5 | AVG |
| OP | 0 | = (equals) |
| | 1 | > (greater than) |
| | 2 | < (less than) |
| | 3 | >= (greater or equal) |
| | 4 | <= (less or equal) |
| | 5 | != (not equal) |
translate to SQL: (task prefix) · 80,654 hand-annotated SQL queries across 24,241 Wikipedia tables. Single-table queries with SELECT, aggregation, and WHERE conditions.