text2sql-demo / app.py
tjhalanigrid's picture
Remove warmup and add strict error handling to catch infinite looping again
79c8c53
"""
GRADIO DEMO UI
NL → SQL → Result Table
"""
import gradio as gr
import pandas as pd
import re
import time
from src.text2sql_engine import get_engine
engine = get_engine()
# =========================
# SAMPLE QUESTIONS DATA
# =========================
# Each entry is a (question, database id) pair: the first element is the
# natural-language question, the second is the database name paired with it.
SAMPLES = [
("Show 10 distinct employee first names.", "chinook_1"),
("Which artist has the most albums?", "chinook_1"),
("List all the tracks that belong to the 'Rock' genre.", "chinook_1"),
("What are the names of all the cities?", "flight_1"),
("Find the flight number and cost of the cheapest flight.", "flight_1"),
("List the airlines that fly out of New York.", "flight_1"),
("Which campus was opened between 1935 and 1939?", "csu_1"),
("Count the number of students in each department.", "college_2"),
("List the names of all clubs.", "club_1"),
("How many members does each club have?", "club_1"),
("Show the names of all cinemas.", "cinema"),
("Which cinema has the most screens?", "cinema")
]
# Question texts only, in the same order as SAMPLES.
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
# =========================
# SQL EXPLAINER
# =========================
def explain_sql(sql):
    """Build a short, human-readable summary of the clauses used in *sql*.

    The summary always starts with a generic sentence and gains one bullet
    for each clause (JOIN, WHERE, GROUP BY, ORDER BY, LIMIT) found in the
    query, in that fixed order.

    Keywords are matched case-insensitively as whole words, so identifiers
    that merely contain a keyword (``unlimited``, ``whereabouts``) do not add
    a bullet, and ``GROUP BY`` / ``ORDER BY`` are still recognised when the
    two words are separated by several spaces or a line break.

    Args:
        sql: SQL text to describe. ``None`` or an empty string yields just
            the generic sentence.

    Returns:
        The explanation as a single newline-separated string.
    """
    explanation = "This SQL query retrieves information from the database."
    if not sql:
        return explanation

    # (pattern, bullet) pairs; tuple order fixes the order of the bullets.
    clause_notes = (
        (r"\bjoin\b", "\n• It combines data from multiple tables using JOIN."),
        (r"\bwhere\b", "\n• It filters rows using a WHERE condition."),
        (r"\bgroup\s+by\b", "\n• It groups results using GROUP BY."),
        (r"\border\s+by\b", "\n• It sorts the results using ORDER BY."),
        (r"\blimit\b", "\n• It limits the number of returned rows."),
    )
    sql_text = str(sql)
    for pattern, note in clause_notes:
        if re.search(pattern, sql_text, re.IGNORECASE):
            explanation += note
    return explanation
# =========================
# CORE FUNCTIONS
# =========================
def run_query(method, sample_q, custom_q, db_id):
    """Translate a question to SQL, run it, and format the outcome for the UI.

    Every path returns normally (never raises on engine or formatting
    problems) so the front end always receives something to display.

    Args:
        method: Selected input mode. ``"💡 Pick a Sample"`` makes ``sample_q``
            the question; any other value makes ``custom_q`` the question.
        sample_q: Question chosen from the sample dropdown.
        custom_q: Free-text question typed by the user.
        db_id: Identifier of the database to query.

    Returns:
        A ``(sql, dataframe, message)`` tuple: the SQL text reported by the
        engine (``""`` when none is available), a ``pandas.DataFrame`` with
        the result rows (empty on any failure), and a status/explanation
        string.
    """
    # 1. Pick the question that belongs to the active input mode.
    question = sample_q if method == "💡 Pick a Sample" else custom_q

    # 2. Validate inputs before hitting the engine.
    if not question or str(question).strip() == "":
        return "", pd.DataFrame(), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "", pd.DataFrame(), "⚠️ Please select a database."

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments.
    started = time.perf_counter()

    # 3. Safety net: any engine failure becomes a message, not an exception.
    try:
        result = engine.ask(str(question), str(db_id))
    except Exception as e:
        return "", pd.DataFrame(), f"❌ CRITICAL BACKEND CRASH:\n{e}"
    latency = round(time.perf_counter() - started, 3)

    # The fields below are read with .get(); anything without it (e.g. None)
    # would otherwise raise outside the safety net above.
    if not callable(getattr(result, "get", None)):
        return (
            "",
            pd.DataFrame(),
            "❌ CRITICAL BACKEND CRASH:\n"
            f"Unexpected engine response of type {type(result).__name__}",
        )

    # Normalise missing/None fields so the regex and DataFrame code below
    # always see a string and sequences.
    final_sql = str(result.get("sql") or "")
    error_msg = result.get("error")
    rows = result.get("rows") or []
    cols = result.get("columns") or []

    # 4. Handle SQL generation/execution errors.
    if error_msg:
        return final_sql, pd.DataFrame(), f"❌ SQL Error:\n{error_msg}"

    # 5/6. Build the result table; an empty result keeps its column headers.
    try:
        if rows:
            # Without column names, let pandas assign its default labels.
            df = pd.DataFrame(rows, columns=cols or None)
        else:
            df = pd.DataFrame(columns=cols)
    except (ValueError, TypeError) as e:
        # Raised e.g. when the row width does not match the column list.
        return final_sql, pd.DataFrame(), f"❌ Result formatting error:\n{e}"

    actual_rows = len(rows)
    explanation = (
        f"✅ Query executed successfully\n\nRows returned: {actual_rows}\n"
        f"Execution Time: {latency} sec\n\n{explain_sql(final_sql)}"
    )

    # Tell the user when a LIMIT was present but fewer rows matched.
    if rows:
        limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
        if limit_match:
            requested_limit = int(limit_match.group(1))
            if actual_rows < requested_limit:
                explanation += (
                    f"\n\nℹ️ Query allowed up to {requested_limit} rows "
                    f"but only {actual_rows} matched."
                )
    return final_sql, df, explanation
def toggle_input_method(method, current_sample):
    """Switch the form between sample mode and free-text mode.

    Returns four component updates, in order: sample dropdown, "type your
    own" warning, custom question box, database dropdown.
    """
    # Free-text mode: reveal the warning and text box, unlock the database.
    if method != "💡 Pick a Sample":
        return (
            gr.update(visible=False),    # sample_dropdown hidden
            gr.update(visible=True),     # type_own_warning shown
            gr.update(visible=True),     # custom_question shown
            gr.update(interactive=True)  # db_id unlocked
        )

    # Sample mode: use the database paired with the current sample, falling
    # back to 'chinook_1' when the sample is not in the list.
    matched_db = "chinook_1"
    for sample_text, sample_db in SAMPLES:
        if sample_text == current_sample:
            matched_db = sample_db
            break

    return (
        gr.update(visible=True),                        # sample_dropdown shown
        gr.update(visible=False),                       # type_own_warning hidden
        gr.update(visible=False),                       # custom_question hidden
        gr.update(value=matched_db, interactive=False)  # db_id reset and locked
    )
def load_sample(selected_question):
    """Return a database-dropdown update matching the chosen sample question.

    An empty selection yields a no-op update; a question that is not in
    SAMPLES selects 'chinook_1'.
    """
    if not selected_question:
        return gr.update()

    # First matching pair wins.
    for question_text, database_name in SAMPLES:
        if question_text == selected_question:
            return gr.update(value=database_name)
    return gr.update(value="chinook_1")
def clear_inputs():
    """Reset the form and the result panes to their initial state.

    Returns an 8-tuple: five component updates (input mode, sample dropdown,
    warning, custom question, database) followed by the cleared output
    values (SQL, table, explanation).
    """
    input_resets = [
        gr.update(value="💡 Pick a Sample"),                 # input_method
        gr.update(value=SAMPLE_QUESTIONS[0], visible=True),  # sample_dropdown
        gr.update(visible=False),                            # type_own_warning
        gr.update(value="", visible=False),                  # custom_question
        gr.update(value="chinook_1", interactive=False),     # db_id
    ]
    output_resets = ["", pd.DataFrame(), ""]  # SQL, table, explanation
    return tuple(input_resets + output_resets)
def update_schema(db_id):
    """Render the schema of *db_id* as an HTML snippet.

    The text returned by ``engine.get_schema`` is processed line by line.
    A line shaped like ``name(...)`` is shown as an upper-cased table name
    followed by its lower-cased column list; any other non-blank line is
    shown unchanged. All schema text and error messages are HTML-escaped
    before being embedded, so characters such as ``<``, ``&`` or quotes
    cannot break or inject markup.

    Args:
        db_id: Identifier of the database whose schema should be shown.
            A falsy value produces an empty string.

    Returns:
        An HTML string: the formatted schema, a red error ``<div>`` if
        anything goes wrong while loading or formatting it, or ``""`` when
        no database is selected.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    if not db_id:
        return ""
    try:
        raw_schema = engine.get_schema(db_id)
        # Compiled once, outside the per-line loop.
        table_pattern = re.compile(r'^([a-zA-Z0-9_]+)\s*\((.*)\)')
        parts = [
            "<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"
        ]
        for line in raw_schema.strip().split('\n'):
            line = line.strip()
            if not line:
                continue
            match = table_pattern.search(line)
            if match:
                table_name = escape(match.group(1).upper())
                columns = escape(match.group(2).lower())
                parts.append(
                    f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{table_name}</strong> <span style='color: #64748b;'>( {columns} )</span></div>"
                )
            else:
                parts.append(f"<div style='color: #475569;'>{escape(line)}</div>")
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        # Deliberately broad: any failure becomes an inline message.
        return f"<div style='color: red;'>Error loading schema: {escape(str(e))}</div>"
# =========================
# UI LAYOUT
# =========================
with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL RLHF") as demo:
    # Page banner, assembled from adjacent literals (same text as before).
    gr.HTML(
        '\n'
        '<div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">\n'
        '<h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>\n'
        '<p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>\n'
        '</div>\n'
    )

    # Databases offered in the dropdown, sorted alphabetically.
    DBS = sorted([
        "flight_1", "student_assessment", "store_1", "bike_1", "book_2",
        "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1",
        "csu_1", "college_1", "college_2", "company_1", "company_employee",
        "customer_complaints", "department_store", "employee_hire_evaluation",
        "museum_visit", "products_for_hire", "restaurant_1", "school_finance",
        "shop_membership", "small_bank_1", "soccer_1", "student_1", "tvshow",
        "voter_1", "world_1",
    ])

    with gr.Row():
        # ---------- Left column: configuration and input ----------
        with gr.Column(scale=1):
            gr.Markdown("### 1. Configuration & Input")

            input_method = gr.Radio(
                choices=["💡 Pick a Sample", "✍️ Type my own"],
                value="💡 Pick a Sample",
                label="How do you want to ask?",
            )

            # Sample picker, visible in sample mode.
            sample_dropdown = gr.Dropdown(
                choices=SAMPLE_QUESTIONS,
                value=SAMPLE_QUESTIONS[0],
                label="Select a Sample Question",
                info="The database will be selected automatically.",
                visible=True,
            )

            # Hint shown only in free-text mode.
            type_own_warning = gr.Markdown(
                "**⚠️ Please select a Database first, then type your custom question below:**",
                visible=False,
            )

            gr.Markdown("---")

            # Database selector; starts locked because sample mode is the default.
            db_id = gr.Dropdown(
                choices=DBS,
                value="chinook_1",
                label="Select Database",
                interactive=False,
            )

            # Free-text question box, hidden until free-text mode is chosen.
            custom_question = gr.Textbox(
                label="Ask your Custom Question",
                placeholder="Type your own question here...",
                lines=3,
                visible=False,
            )

            gr.Markdown("#### 📋 Database Structure")
            gr.HTML(
                "<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>"
                "Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>"
            )
            # Pre-rendered with the schema of the default database.
            schema_display = gr.HTML(value=update_schema("chinook_1"))

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                run_btn = gr.Button(" Generate & Run SQL", variant="primary")

        # ---------- Right column: execution results ----------
        with gr.Column(scale=2):
            gr.Markdown("### 2. Execution Results")
            final_sql = gr.Code(language="sql", label="Final Executed SQL")
            result_table = gr.Dataframe(
                label="Query Result Table",
                interactive=False,
                wrap=True,
            )
            explanation = gr.Textbox(
                label="AI Explanation + Execution Details",
                lines=8,
            )

    # =========================
    # EVENT LISTENERS
    # =========================
    # Switching input mode shows/hides the sample picker, the warning and the
    # custom question box, and locks or unlocks the database dropdown.
    input_method.change(
        fn=toggle_input_method,
        inputs=[input_method, sample_dropdown],
        outputs=[sample_dropdown, type_own_warning, custom_question, db_id],
    )

    # Choosing a sample selects its database.
    sample_dropdown.change(
        fn=load_sample,
        inputs=[sample_dropdown],
        outputs=[db_id],
    )

    # Changing the database refreshes the schema panel.
    db_id.change(
        fn=update_schema,
        inputs=[db_id],
        outputs=[schema_display],
    )

    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation],
    )

    # The outputs list follows the order of the values returned by clear_inputs().
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[
            input_method,
            sample_dropdown,
            type_own_warning,
            custom_question,
            db_id,
            final_sql,
            result_table,
            explanation,
        ],
    )

if __name__ == "__main__":
    demo.launch()