"""Interactive Gradio Space for debugging text-to-SQL models.

Lets a user pick a dataset version (LLMSQL 1.0 / 2.0), a causal-LM checkpoint,
a few-shot prompt template and a question ID, then runs generation on GPU,
executes both predicted and gold SQL against the version-specific SQLite DB,
and shows the results side by side.
"""

import os
import sqlite3

import gradio as gr
import pandas as pd
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Import local prompt builders
from build_prompt import (
    build_prompt_0shot,
    build_prompt_1shot,
    build_prompt_5shot,
)

# Import both versions of the dataset
from dataset_generator import (
    questions_llmsql_1,
    questions_llmsql_2,
    tables_llmsql_1,
    tables_llmsql_2,
)

# NOTE(review): both aliases bind the SAME `split_info` object, so the 1.0 and
# 2.0 entries in DATASETS share one split table. If the two versions have
# distinct split metadata these should come from different sources — confirm
# against dataset_generator.
from dataset_generator import split_info as split_info_v1
from dataset_generator import split_info as split_info_v2
from evaluate import evaluate_sample

# Global caches so repeated requests reuse the loaded model and DB connection
# instead of re-downloading / re-loading on every click.
model = None
tokenizer = None
current_model_id = None
conn = None
current_db_path = None

# Mapping for dynamic access: per-version question/table dicts, split metadata,
# and the HF dataset repo that hosts the matching SQLite database.
DATASETS = {
    "LLMSQL 1.0": {
        "questions": questions_llmsql_1,
        "tables": tables_llmsql_1,
        "split_info": split_info_v1,
        "repo": "llmsql-bench/llmsql-benchmark",
        "folder": "llmsql_1.0",
    },
    "LLMSQL 2.0": {
        "questions": questions_llmsql_2,
        "tables": tables_llmsql_2,
        "split_info": split_info_v2,
        "repo": "llmsql-bench/llmsql-2.0",
        "folder": "llmsql_2.0",
    },
}


# =====================
# Initialization Logic
# =====================
def initialize_data(version):
    """Downloads and connects to the version-specific database.

    Idempotent: only downloads when the DB file is missing and only reopens
    the global connection when the requested version differs from the one
    currently connected.
    """
    global conn, current_db_path

    config = DATASETS[version]
    APP_DIR = os.getcwd()
    VER_DIR = os.path.join(APP_DIR, config["folder"])
    os.makedirs(VER_DIR, exist_ok=True)
    DB_FILE = os.path.join(VER_DIR, "sqlite_tables.db")

    # Only reconnect if the version changed or conn is None
    if current_db_path != DB_FILE:
        if not os.path.exists(DB_FILE):
            print(f"Downloading database for {version}...")
            hf_hub_download(
                repo_id=config["repo"],
                repo_type="dataset",
                filename="sqlite_tables.db",
                local_dir=VER_DIR,
            )
        if conn:
            conn.close()
        # check_same_thread=False: Gradio handlers may run on worker threads.
        conn = sqlite3.connect(DB_FILE, check_same_thread=False)
        current_db_path = DB_FILE


def load_model_if_needed(model_id):
    """Handles switching models inside the GPU space.

    No-op when `model_id` matches the currently cached model; otherwise loads
    tokenizer and model onto CUDA and updates the global cache.
    """
    global model, tokenizer, current_model_id

    if current_model_id == model_id and model is not None:
        return

    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="cuda",
        trust_remote_code=True,
    )
    current_model_id = model_id
    print(f"Model {model_id} loaded successfully.")


# Dispatch table: few-shot setting (as a string, matching the dropdown
# choices) -> prompt-builder function.
few_shot_selection = {
    "0": build_prompt_0shot,
    "1": build_prompt_1shot,
    "5": build_prompt_5shot,
}


# =====================
# Main Logic
# =====================
@spaces.GPU
def run_inference(version, model_id, question_idx, few_shots):
    """Generate SQL for one benchmark question and evaluate it.

    Returns a 7-tuple:
    (question text, generated SQL, gold SQL, prediction result rows,
     gold result rows, full source table as a DataFrame, exact-match flag).
    Error paths return placeholder strings and Nones in the same shape.
    """
    initialize_data(version)
    load_model_if_needed(model_id)

    dataset = DATASETS[version]
    qs = dataset["questions"]
    ts = dataset["tables"]

    # Narrowed from a bare `except:` — only the failures int()/indexing can
    # actually raise for a bad ID should map to the "Invalid ID" response.
    try:
        idx = int(question_idx)
        q_data = qs[idx]
    except (TypeError, ValueError, KeyError, IndexError):
        return "Invalid ID", "", "", None, None, None, False

    question = q_data["question"]
    ground_truth_sql = q_data["sql"]
    table = ts.get(q_data["table_id"])
    if not table:
        return "Table data missing", "", "", None, None, None, False

    example_row = table["rows"][0] if table["rows"] else []
    raw_prompt = few_shot_selection[few_shots](
        question, table["header"], table["types"], example_row
    )

    messages = [{"role": "user", "content": raw_prompt}]
    text_input = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text_input], return_tensors="pt").to("cuda")

    # Greedy decoding; `temperature` removed because it is ignored (and
    # warned about by transformers) when do_sample=False.
    generated_ids = model.generate(
        **model_inputs, max_new_tokens=512, do_sample=False
    )
    # Strip the prompt tokens so only the completion is decoded.
    new_tokens = generated_ids[0][len(model_inputs.input_ids[0]):]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)

    is_match, mismatch_info, _ = evaluate_sample(
        item={"question_id": idx, "completion": completion},
        questions=qs,
        conn=conn,
    )

    return (
        question,
        completion,
        ground_truth_sql,
        mismatch_info["prediction_results"],
        mismatch_info["gold_results"],
        pd.DataFrame(table["rows"], columns=table["header"]),
        bool(is_match),
    )


# =====================
# UI Helpers
# =====================
def get_range_display(version):
    """Render the version's split ID ranges as a Markdown bullet list."""
    info_dict = DATASETS[version]["split_info"]
    lines = []
    for s, info in info_dict.items():
        lines.append(
            f"**{s.capitalize()}**: IDs {info.get('first')} to {info.get('last')} (Total: {info.get('count')})"
        )
    return "\n".join(lines)


with gr.Blocks(title="Text-to-SQL Debugger", theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🔍 Text-to-SQL Interactive Debugger")

    with gr.Row():
        version_dropdown = gr.Dropdown(
            choices=["LLMSQL 1.0", "LLMSQL 2.0"],
            value="LLMSQL 2.0",
            label="Dataset Version",
            scale=1,
        )
        model_dropdown = gr.Dropdown(
            choices=["Qwen/Qwen2.5-1.5B-Instruct", "openai/gpt-oss-20b"],
            value="Qwen/Qwen2.5-1.5B-Instruct",
            label="1. Select Model",
            scale=1,
        )
        few_shot_dropdown = gr.Dropdown(
            choices=["0", "1", "5"],
            value="5",
            label="2. Few-shot Examples",
            scale=1,
        )

    with gr.Row():
        question_input = gr.Textbox(
            label="3. Enter Question ID",
            value="1",
            lines=2,
            placeholder="e.g. 15001",
            scale=2,
            min_width=200,
        )
        with gr.Column(scale=1):
            range_md = gr.Markdown(
                get_range_display("LLMSQL 2.0"),
                line_breaks=True,
                padding=True,
            )

    run_button = gr.Button("Run Inference", variant="primary")

    question_box = gr.Textbox(
        label="Natural Language Question", lines=2, interactive=False
    )

    with gr.Row():
        generated_sql_box = gr.Code(label="Generated SQL", language="sql", lines=3)
        gt_sql_box = gr.Code(label="Ground Truth SQL", language="sql", lines=3)

    gr.Markdown("### Data Comparison")
    with gr.Row():
        generated_table = gr.Dataframe(label="Generated Result", type="pandas")
        gt_table = gr.Dataframe(label="Ground Truth Result", type="pandas")

    with gr.Accordion("See Full Source Table", open=False):
        full_table = gr.Dataframe(label="Full Table Content", type="pandas")

    def update_ui_on_version_or_id(version, q_id):
        """Updates range text and pre-loads question data when version or ID changes."""
        dataset = DATASETS[version]
        range_text = get_range_display(version)
        try:
            # Fall back to ID 1 on empty / non-numeric input.
            idx = int(q_id) if (q_id and str(q_id).isdigit()) else 1
            q_data = dataset["questions"][idx]
            table_id = q_data["table_id"]
            raw_table = dataset["tables"].get(table_id, {})
            df = pd.DataFrame(
                raw_table.get("rows", []), columns=raw_table.get("header", [])
            )
            return (
                q_data["question"],
                q_data["sql"],
                df,
                # Reset the generated-SQL panel to its neutral state.
                gr.update(label="Generated SQL", value=""),
                range_text,
            )
        except Exception:
            # UI boundary: show a friendly fallback instead of crashing.
            return (
                "ID not found in this version",
                "",
                pd.DataFrame(),
                gr.update(label="Generated SQL"),
                range_text,
            )

    def handle_inference(version, model, few_shot, q_id):
        """Run inference and relabel the SQL panel with the match status."""
        q_text, gen_sql, gt_sql, gen_df, gt_df, full_df, is_match = run_inference(
            version, model, q_id, few_shot
        )
        status = (
            "✅ Generated SQL (MATCH SUCCESS)"
            if is_match
            else "❌ Generated SQL (MATCH FAILED)"
        )
        return gr.update(label=status, value=gen_sql), gen_df, gt_df

    # Event Listeners
    version_dropdown.change(
        update_ui_on_version_or_id,
        inputs=[version_dropdown, question_input],
        outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
    )
    question_input.change(
        update_ui_on_version_or_id,
        inputs=[version_dropdown, question_input],
        outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
    )
    run_button.click(
        handle_inference,
        inputs=[version_dropdown, model_dropdown, few_shot_dropdown, question_input],
        outputs=[generated_sql_box, generated_table, gt_table],
    )
    app.load(
        update_ui_on_version_or_id,
        inputs=[version_dropdown, question_input],
        outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
    )

app.launch()