""" GRADIO DEMO UI NL → SQL → Result Table """ import gradio as gr import pandas as pd import re import time from src.text2sql_engine import get_engine engine = get_engine() # ========================= # SAMPLE QUESTIONS DATA # ========================= SAMPLES = [ ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"), ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"), ("Find the flight number and cost of the cheapest flight.", "flight_1"), ("List the airlines that fly out of New York.", "flight_1"), ("Which campus was opened between 1935 and 1939?", "csu_1"), ("Count the number of students in each department.", "college_2"), ("List the names of all clubs.", "club_1"), ("How many members does each club have?", "club_1"), ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema") ] SAMPLE_QUESTIONS = [q[0] for q in SAMPLES] # ========================= # SQL EXPLAINER # ========================= def explain_sql(sql): explanation = "This SQL query retrieves information from the database." sql_lower = sql.lower() if "join" in sql_lower: explanation += "\n• It combines data from multiple tables using JOIN." if "where" in sql_lower: explanation += "\n• It filters rows using a WHERE condition." if "group by" in sql_lower: explanation += "\n• It groups results using GROUP BY." if "order by" in sql_lower: explanation += "\n• It sorts the results using ORDER BY." if "limit" in sql_lower: explanation += "\n• It limits the number of returned rows." return explanation # ========================= # CORE FUNCTIONS # ========================= def run_query(method, sample_q, custom_q, db_id): # 1. Safely determine the question question = sample_q if method == "💡 Pick a Sample" else custom_q # 2. Validate inputs before hitting the engine if not question or str(question).strip() == "": return "", pd.DataFrame(), "⚠️ Please enter a question." if not db_id or str(db_id).strip() == "": return "", pd.DataFrame(), "⚠️ Please select a database." start_time = time.time() # 3. GIANT SAFETY NET to prevent infinite loading spinners try: result = engine.ask(str(question), str(db_id)) except Exception as e: return "", pd.DataFrame(), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}" final_sql = result.get("sql", "") error_msg = result.get("error", None) rows = result.get("rows", []) cols = result.get("columns", []) end_time = time.time() latency = round(end_time - start_time, 3) # 4. Handle SQL generation/execution errors if error_msg: return final_sql, pd.DataFrame(), f"❌ SQL Error:\n{error_msg}" # 5. Handle Zero Rows gracefully if not rows: df = pd.DataFrame(columns=cols if cols else []) explanation = f"✅ Query executed successfully\n\nRows returned: 0\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}" return final_sql, df, explanation # 6. Handle successful execution df = pd.DataFrame(rows, columns=cols) actual_rows = len(rows) explanation = f"✅ Query executed successfully\n\nRows returned: {actual_rows}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}" limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE) if limit_match: requested_limit = int(limit_match.group(1)) if actual_rows < requested_limit: explanation += f"\n\nℹ️ Query allowed up to {requested_limit} rows but only {actual_rows} matched." return final_sql, df, explanation def toggle_input_method(method, current_sample): if method == "💡 Pick a Sample": # Find the DB matching the current sample (fallback to 'chinook_1') db = next((db for q, db in SAMPLES if q == current_sample), "chinook_1") return ( gr.update(visible=True), # Show sample_dropdown gr.update(visible=False), # Hide type_own_warning gr.update(visible=False), # Hide custom_question gr.update(value=db, interactive=False) # Lock and reset db_id ) else: return ( gr.update(visible=False), # Hide sample_dropdown gr.update(visible=True), # Show type_own_warning gr.update(visible=True), # Show custom_question gr.update(interactive=True) # Unlock db_id ) def load_sample(selected_question): if not selected_question: return gr.update() db = next((db for q, db in SAMPLES if q == selected_question), "chinook_1") return gr.update(value=db) def clear_inputs(): return ( gr.update(value="💡 Pick a Sample"), gr.update(value=SAMPLE_QUESTIONS[0], visible=True), # sample_dropdown gr.update(visible=False), # type_own_warning gr.update(value="", visible=False), # custom_question gr.update(value="chinook_1", interactive=False), # db_id "", pd.DataFrame(), "" # Outputs (SQL, Table, Explanation) ) def update_schema(db_id): if not db_id: return "" try: raw_schema = engine.get_schema(db_id) html_output = "
" for line in raw_schema.strip().split('\n'): line = line.strip() if not line: continue match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', line) if match: table_name = match.group(1).upper() columns = match.group(2).lower() html_output += f"
{table_name} ( {columns} )
" else: html_output += f"
{line}
" html_output += "
" return html_output except Exception as e: return f"
Error loading schema: {str(e)}
" # ========================= # UI LAYOUT # ========================= with gr.Blocks(theme=gr.themes.Soft(), title="Text-to-SQL RLHF") as demo: gr.HTML( """

Text-to-SQL using RLHF + Execution Reward

Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.

""" ) DBS = sorted([ "flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "soccer_1", "student_1", "tvshow", "voter_1", "world_1" ]) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 1. Configuration & Input") input_method = gr.Radio( choices=["💡 Pick a Sample", "✍️ Type my own"], value="💡 Pick a Sample", label="How do you want to ask?" ) # --- SAMPLE SECTION --- sample_dropdown = gr.Dropdown( choices=SAMPLE_QUESTIONS, value=SAMPLE_QUESTIONS[0], label="Select a Sample Question", info="The database will be selected automatically.", visible=True ) # --- CUSTOM TYPE WARNING --- type_own_warning = gr.Markdown( "**⚠️ Please select a Database first, then type your custom question below:**", visible=False ) gr.Markdown("---") # --- DATABASE SELECTION (Moved Up) --- db_id = gr.Dropdown( choices=DBS, value="chinook_1", label="Select Database", interactive=False ) # --- CUSTOM QUESTION BOX --- custom_question = gr.Textbox( label="Ask your Custom Question", placeholder="Type your own question here...", lines=3, visible=False ) gr.Markdown("#### 📋 Database Structure") gr.HTML("

Use these exact names! Table names are Dark, Column names are Light.

") schema_display = gr.HTML(value=update_schema("chinook_1")) with gr.Row(): clear_btn = gr.Button("🗑️ Clear", variant="secondary") run_btn = gr.Button(" Generate & Run SQL", variant="primary") with gr.Column(scale=2): gr.Markdown("### 2. Execution Results") final_sql = gr.Code(language="sql", label="Final Executed SQL") result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True) explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8) # ========================= # EVENT LISTENERS # ========================= # Updated to handle the new Markdown warning toggle input_method.change( fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id] ) sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id]) db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display]) run_btn.click( fn=run_query, inputs=[input_method, sample_dropdown, custom_question, db_id], outputs=[final_sql, result_table, explanation] ) clear_btn.click( fn=clear_inputs, inputs=[], # Output list matches the updated clear_inputs() return values outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation] ) if __name__ == "__main__": demo.launch()