Spaces:
Sleeping
Sleeping
| import os | |
| import sqlite3 | |
| import gradio as gr | |
| import pandas as pd | |
| import spaces | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # Import local prompt builders | |
| from build_prompt import ( | |
| build_prompt_0shot, | |
| build_prompt_1shot, | |
| build_prompt_5shot, | |
| ) | |
| # Import both versions of the dataset | |
| from dataset_generator import ( | |
| questions_llmsql_1, | |
| questions_llmsql_2, | |
| tables_llmsql_1, | |
| tables_llmsql_2, | |
| ) | |
| from dataset_generator import split_info as split_info_v1 | |
| from dataset_generator import split_info as split_info_v2 | |
| from evaluate import evaluate_sample | |
# Global variables for caching
# Process-wide lazy state shared across Gradio callbacks: the loaded model /
# tokenizer pair, which repo id it came from, and the active sqlite connection.
model = None              # currently loaded causal LM (None until first inference)
tokenizer = None          # tokenizer paired with `model`
current_model_id = None   # repo id of the resident model; used to skip reloads
conn = None               # sqlite3 connection to the active benchmark database
current_db_path = None    # path `conn` points at; used to detect version switches
# Mapping for dynamic access
# Per-version bundle: in-memory questions/tables, split metadata, and the
# HF Hub repo + local folder that hold the matching sqlite database.
DATASETS = {
    "LLMSQL 1.0": {
        "questions": questions_llmsql_1,
        "tables": tables_llmsql_1,
        # NOTE(review): split_info_v1 and split_info_v2 alias the SAME
        # `dataset_generator.split_info` object (see the imports above), so
        # both versions display identical ID ranges — confirm whether
        # version-specific split info was intended.
        "split_info": split_info_v1,
        "repo": "llmsql-bench/llmsql-benchmark",
        "folder": "llmsql_1.0",
    },
    "LLMSQL 2.0": {
        "questions": questions_llmsql_2,
        "tables": tables_llmsql_2,
        "split_info": split_info_v2,
        "repo": "llmsql-bench/llmsql-2.0",
        "folder": "llmsql_2.0",
    },
}
| # ===================== | |
| # Initialization Logic | |
| # ===================== | |
def initialize_data(version):
    """Download (if needed) and connect to the version-specific database.

    Args:
        version: A key of ``DATASETS`` (e.g. ``"LLMSQL 2.0"``).

    Side effects:
        Rebinds the module-level ``conn`` and ``current_db_path`` globals,
        closing any previously open connection.
    """
    global conn, current_db_path
    config = DATASETS[version]
    app_dir = os.getcwd()
    ver_dir = os.path.join(app_dir, config["folder"])
    os.makedirs(ver_dir, exist_ok=True)
    db_file = os.path.join(ver_dir, "sqlite_tables.db")
    # Reconnect when the version changed OR there is no live connection.
    # BUGFIX: the original only compared paths, so a None `conn` with an
    # unchanged `current_db_path` was never (re)connected, contradicting
    # the original comment "version changed or conn is None".
    if conn is None or current_db_path != db_file:
        if not os.path.exists(db_file):
            print(f"Downloading database for {version}...")
            hf_hub_download(
                repo_id=config["repo"],
                repo_type="dataset",
                filename="sqlite_tables.db",
                local_dir=ver_dir,
            )
        if conn:
            conn.close()
        # check_same_thread=False: Gradio may call back from worker threads.
        conn = sqlite3.connect(db_file, check_same_thread=False)
        current_db_path = db_file
def load_model_if_needed(model_id):
    """Load ``model_id`` into the global model/tokenizer slots.

    Skips the (expensive) load entirely when the requested model is
    already resident, so repeated inferences reuse the cached weights.
    """
    global model, tokenizer, current_model_id
    already_resident = model is not None and current_model_id == model_id
    if already_resident:
        return
    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="cuda",
        trust_remote_code=True,
    )
    current_model_id = model_id
    print(f"Model {model_id} loaded successfully.")
# Dispatch table: number of few-shot examples (as a string, matching the
# values of the UI dropdown) -> prompt-builder function.
few_shot_selection = {
    "0": build_prompt_0shot,
    "1": build_prompt_1shot,
    "5": build_prompt_5shot,
}
| # ===================== | |
| # Main Logic | |
| # ===================== | |
def run_inference(version, model_id, question_idx, few_shots):
    """Run text-to-SQL inference for a single benchmark question.

    Args:
        version: Key of ``DATASETS`` selecting the benchmark version.
        model_id: HF Hub repo id of the model to use.
        question_idx: Question ID; anything ``int()`` accepts.
        few_shots: "0" | "1" | "5" — key of ``few_shot_selection``.

    Returns:
        A 7-tuple ``(question, generated_sql, ground_truth_sql,
        prediction_results, gold_results, full_table_df, is_match)``.
        On bad input the first element is an error string and the rest are
        empty/None/False placeholders.
    """
    initialize_data(version)
    load_model_if_needed(model_id)
    dataset = DATASETS[version]
    qs = dataset["questions"]
    ts = dataset["tables"]
    # BUGFIX: narrowed the bare `except:` (which also swallowed
    # KeyboardInterrupt/SystemExit) to the failures int()/indexing raise.
    try:
        idx = int(question_idx)
        q_data = qs[idx]
    except (TypeError, ValueError, IndexError, KeyError):
        return "Invalid ID", "", "", None, None, None, False
    question = q_data["question"]
    ground_truth_sql = q_data["sql"]
    table = ts.get(q_data["table_id"])
    if not table:
        return "Table data missing", "", "", None, None, None, False
    # Guard against empty tables when picking the prompt's example row.
    example_row = table["rows"][0] if table["rows"] else []
    raw_prompt = few_shot_selection[few_shots](
        question, table["header"], table["types"], example_row
    )
    messages = [{"role": "user", "content": raw_prompt}]
    text_input = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text_input], return_tensors="pt").to("cuda")
    # Greedy decoding. BUGFIX: dropped temperature=0.0 — recent transformers
    # versions reject a non-positive temperature even with do_sample=False;
    # the sampled distribution is unused in greedy mode anyway.
    generated_ids = model.generate(
        **model_inputs, max_new_tokens=512, do_sample=False
    )
    # Strip the echoed prompt tokens; keep only the newly generated tail.
    new_tokens = generated_ids[0][len(model_inputs.input_ids[0]):]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
    is_match, mismatch_info, _ = evaluate_sample(
        item={"question_id": idx, "completion": completion},
        questions=qs,
        conn=conn,
    )
    return (
        question,
        completion,
        ground_truth_sql,
        mismatch_info["prediction_results"],
        mismatch_info["gold_results"],
        pd.DataFrame(table["rows"], columns=table["header"]),
        bool(is_match),
    )
| # ===================== | |
| # UI Helpers | |
| # ===================== | |
def get_range_display(version):
    """Render the split ID ranges for ``version`` as markdown, one per line."""
    split_meta = DATASETS[version]["split_info"]
    return "\n".join(
        f"**{split.capitalize()}**: IDs {meta.get('first')} to {meta.get('last')}"
        f" (Total: {meta.get('count')})"
        for split, meta in split_meta.items()
    )
# ---------------------------------------------------------------------------
# Gradio UI definition (flat script: widgets, then callbacks and wiring below)
# ---------------------------------------------------------------------------
with gr.Blocks(title="Text-to-SQL Debugger", theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🔍 Text-to-SQL Interactive Debugger")
    # Top row: dataset version / model / few-shot configuration.
    with gr.Row():
        version_dropdown = gr.Dropdown(
            choices=["LLMSQL 1.0", "LLMSQL 2.0"],
            value="LLMSQL 2.0",
            label="Dataset Version",
            scale=1,
        )
        model_dropdown = gr.Dropdown(
            choices=["Qwen/Qwen2.5-1.5B-Instruct", "openai/gpt-oss-20b"],
            value="Qwen/Qwen2.5-1.5B-Instruct",
            label="1. Select Model",
            scale=1,
        )
        few_shot_dropdown = gr.Dropdown(
            choices=["0", "1", "5"],
            value="5",
            label="2. Few-shot Examples",
            scale=1,
        )
    # Second row: question ID input next to the valid-ID-range help text.
    with gr.Row():
        question_input = gr.Textbox(
            label="3. Enter Question ID",
            value="1",
            lines=2,
            placeholder="e.g. 15001",
            scale=2,
            min_width=200,
        )
        with gr.Column(scale=1):
            # Shows per-split ID ranges; refreshed when the version changes.
            range_md = gr.Markdown(
                get_range_display("LLMSQL 2.0"),
                line_breaks=True,
                padding=True,
            )
    run_button = gr.Button("Run Inference", variant="primary")
    question_box = gr.Textbox(
        label="Natural Language Question", lines=2, interactive=False
    )
    # Side-by-side SQL comparison: model output vs. gold query.
    with gr.Row():
        generated_sql_box = gr.Code(label="Generated SQL", language="sql", lines=3)
        gt_sql_box = gr.Code(label="Ground Truth SQL", language="sql", lines=3)
    gr.Markdown("### Data Comparison")
    # Side-by-side execution results of the two queries.
    with gr.Row():
        generated_table = gr.Dataframe(label="Generated Result", type="pandas")
        gt_table = gr.Dataframe(label="Ground Truth Result", type="pandas")
    with gr.Accordion("See Full Source Table", open=False):
        full_table = gr.Dataframe(label="Full Table Content", type="pandas")
| def update_ui_on_version_or_id(version, q_id): | |
| """Updates range text and pre-loads question data when version or ID changes.""" | |
| dataset = DATASETS[version] | |
| range_text = get_range_display(version) | |
| try: | |
| idx = int(q_id) if (q_id and str(q_id).isdigit()) else 1 | |
| q_data = dataset["questions"][idx] | |
| table_id = q_data["table_id"] | |
| raw_table = dataset["tables"].get(table_id, {}) | |
| df = pd.DataFrame( | |
| raw_table.get("rows", []), columns=raw_table.get("header", []) | |
| ) | |
| return ( | |
| q_data["question"], | |
| q_data["sql"], | |
| df, | |
| gr.update(label="Generated SQL", value=""), | |
| range_text, | |
| ) | |
| except Exception: | |
| return ( | |
| "ID not found in this version", | |
| "", | |
| pd.DataFrame(), | |
| gr.update(label="Generated SQL"), | |
| range_text, | |
| ) | |
| def handle_inference(version, model, few_shot, q_id): | |
| q_text, gen_sql, gt_sql, gen_df, gt_df, full_df, is_match = run_inference( | |
| version, model, q_id, few_shot | |
| ) | |
| status = ( | |
| "✅ Generated SQL (MATCH SUCCESS)" | |
| if is_match | |
| else "❌ Generated SQL (MATCH FAILED)" | |
| ) | |
| return gr.update(label=status, value=gen_sql), gen_df, gt_df | |
    # Event Listeners
    # Changing the dataset version or the question ID refreshes the preview
    # widgets (question, gold SQL, source table) without touching the model.
    version_dropdown.change(
        update_ui_on_version_or_id,
        inputs=[version_dropdown, question_input],
        outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
    )
    question_input.change(
        update_ui_on_version_or_id,
        inputs=[version_dropdown, question_input],
        outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
    )
    # The button runs the (GPU) inference path and fills the result tables.
    run_button.click(
        handle_inference,
        inputs=[version_dropdown, model_dropdown, few_shot_dropdown, question_input],
        outputs=[generated_sql_box, generated_table, gt_table],
    )
    # Populate the preview once on page load.
    app.load(
        update_ui_on_version_or_id,
        inputs=[version_dropdown, question_input],
        outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
    )

app.launch()