# fix: markdown scale (commit d5d175f)
# NOTE(review): the three lines above were leftover Hugging Face file-viewer
# page text (author avatar caption, commit message, commit hash); kept here as
# a comment so the module parses.
import os
import sqlite3
import gradio as gr
import pandas as pd
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Import local prompt builders
from build_prompt import (
    build_prompt_0shot,
    build_prompt_1shot,
    build_prompt_5shot,
)

# Import both versions of the dataset
from dataset_generator import (
    questions_llmsql_1,
    questions_llmsql_2,
    tables_llmsql_1,
    tables_llmsql_2,
)
from dataset_generator import split_info as split_info_v1
# NOTE(review): split_info_v2 aliases the SAME object as split_info_v1 (both
# come from dataset_generator.split_info). If the two benchmark versions have
# distinct ID ranges, the second import probably needs a different source —
# verify against dataset_generator.
from dataset_generator import split_info as split_info_v2
from evaluate import evaluate_sample
# Global variables for caching — repeated Gradio calls reuse the loaded model
# and the open sqlite connection instead of re-initializing on every request.
model = None              # currently resident causal LM (None until first load)
tokenizer = None          # tokenizer paired with `model`
current_model_id = None   # HF repo id of the model held in `model`
conn = None               # sqlite3 connection to the active benchmark DB
current_db_path = None    # filesystem path backing `conn`

# Mapping for dynamic access: per-version question/table data plus the HF
# dataset repo and local folder that hold the matching sqlite database.
DATASETS = {
    "LLMSQL 1.0": {
        "questions": questions_llmsql_1,
        "tables": tables_llmsql_1,
        "split_info": split_info_v1,
        "repo": "llmsql-bench/llmsql-benchmark",
        "folder": "llmsql_1.0",
    },
    "LLMSQL 2.0": {
        "questions": questions_llmsql_2,
        "tables": tables_llmsql_2,
        "split_info": split_info_v2,
        "repo": "llmsql-bench/llmsql-2.0",
        "folder": "llmsql_2.0",
    },
}
# =====================
# Initialization Logic
# =====================
def initialize_data(version):
    """Ensure the sqlite DB for *version* exists locally and is connected.

    Downloads the database from the version's HF dataset repo on first use,
    then swaps the module-level connection whenever the selected version's
    DB file differs from the one currently open.
    """
    global conn, current_db_path

    cfg = DATASETS[version]
    version_dir = os.path.join(os.getcwd(), cfg["folder"])
    os.makedirs(version_dir, exist_ok=True)
    db_file = os.path.join(version_dir, "sqlite_tables.db")

    # Already connected to this exact DB file — nothing to do.
    if current_db_path == db_file:
        return

    if not os.path.exists(db_file):
        print(f"Downloading database for {version}...")
        hf_hub_download(
            repo_id=cfg["repo"],
            repo_type="dataset",
            filename="sqlite_tables.db",
            local_dir=version_dir,
        )

    if conn:
        conn.close()
    # check_same_thread=False: Gradio may invoke handlers from worker threads.
    conn = sqlite3.connect(db_file, check_same_thread=False)
    current_db_path = db_file
def load_model_if_needed(model_id):
    """Load *model_id* into the module-level model/tokenizer caches.

    No-op when the requested model is already resident; otherwise the
    previous model object is replaced (its GPU memory is reclaimed by GC).
    """
    global model, tokenizer, current_model_id

    already_resident = model is not None and current_model_id == model_id
    if already_resident:
        return

    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="cuda",
        trust_remote_code=True,
    )
    current_model_id = model_id
    print(f"Model {model_id} loaded successfully.")
# Dispatch table: few-shot setting (string, matching the UI dropdown choices)
# -> prompt-builder function imported from build_prompt.
few_shot_selection = {
    "0": build_prompt_0shot,
    "1": build_prompt_1shot,
    "5": build_prompt_5shot,
}
# =====================
# Main Logic
# =====================
@spaces.GPU
def run_inference(version, model_id, question_idx, few_shots):
    """Generate SQL for one benchmark question and evaluate it.

    Returns a 7-tuple consumed by the UI handler:
    (question text, model completion, ground-truth SQL, prediction result
    rows, gold result rows, full source table as a DataFrame, exact-match
    flag). On lookup failure, placeholder values with ``False`` are returned
    instead of raising, so the UI stays responsive.
    """
    initialize_data(version)
    load_model_if_needed(model_id)

    dataset = DATASETS[version]
    qs = dataset["questions"]
    ts = dataset["tables"]

    # Fix: the original bare `except:` swallowed *every* exception (including
    # KeyboardInterrupt/SystemExit). Catch only the expected failures from
    # int() conversion and the question lookup.
    try:
        idx = int(question_idx)
        q_data = qs[idx]
    except (TypeError, ValueError, IndexError, KeyError):
        return "Invalid ID", "", "", None, None, None, False

    question = q_data["question"]
    ground_truth_sql = q_data["sql"]
    table = ts.get(q_data["table_id"])
    if not table:
        return "Table data missing", "", "", None, None, None, False

    example_row = table["rows"][0] if table["rows"] else []
    raw_prompt = few_shot_selection[few_shots](
        question, table["header"], table["types"], example_row
    )

    messages = [{"role": "user", "content": raw_prompt}]
    text_input = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text_input], return_tensors="pt").to("cuda")

    # Greedy decoding. The original also passed temperature=0.0, which is
    # ignored (with a warning) when do_sample=False — dropped here.
    generated_ids = model.generate(**model_inputs, max_new_tokens=512, do_sample=False)
    new_tokens = generated_ids[0][len(model_inputs.input_ids[0]):]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)

    is_match, mismatch_info, _ = evaluate_sample(
        item={"question_id": idx, "completion": completion},
        questions=qs,
        conn=conn,
    )
    return (
        question,
        completion,
        ground_truth_sql,
        mismatch_info["prediction_results"],
        mismatch_info["gold_results"],
        pd.DataFrame(table["rows"], columns=table["header"]),
        bool(is_match),
    )
# =====================
# UI Helpers
# =====================
def get_range_display(version):
    """Render the split ID ranges for *version* as Markdown, one per line."""
    splits = DATASETS[version]["split_info"]
    return "\n".join(
        f"**{name.capitalize()}**: IDs {meta.get('first')} to "
        f"{meta.get('last')} (Total: {meta.get('count')})"
        for name, meta in splits.items()
    )
# UI layout: all components are created inside one Blocks context; the
# variables bound here are wired together in the event-listener section below.
with gr.Blocks(title="Text-to-SQL Debugger", theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🔍 Text-to-SQL Interactive Debugger")
    with gr.Row():
        # Top control row: dataset version, model, few-shot count.
        version_dropdown = gr.Dropdown(
            choices=["LLMSQL 1.0", "LLMSQL 2.0"],
            value="LLMSQL 2.0",
            label="Dataset Version",
            scale=1,
        )
        model_dropdown = gr.Dropdown(
            choices=["Qwen/Qwen2.5-1.5B-Instruct", "openai/gpt-oss-20b"],
            value="Qwen/Qwen2.5-1.5B-Instruct",
            label="1. Select Model",
            scale=1,
        )
        # Keys must match few_shot_selection in this module.
        few_shot_dropdown = gr.Dropdown(
            choices=["0", "1", "5"],
            value="5",
            label="2. Few-shot Examples",
            scale=1,
        )
    with gr.Row():
        question_input = gr.Textbox(
            label="3. Enter Question ID",
            value="1",
            lines=2,
            placeholder="e.g. 15001",
            scale=2,
            min_width=200,
        )
        with gr.Column(scale=1):
            # Shows the valid ID ranges for the initially selected version;
            # refreshed by the change listeners below.
            range_md = gr.Markdown(
                get_range_display("LLMSQL 2.0"),
                line_breaks=True,
                padding=True,
            )
    run_button = gr.Button("Run Inference", variant="primary")
    question_box = gr.Textbox(
        label="Natural Language Question", lines=2, interactive=False
    )
    with gr.Row():
        # Side-by-side SQL comparison; the generated box's label is rewritten
        # with a match/fail status after inference.
        generated_sql_box = gr.Code(label="Generated SQL", language="sql", lines=3)
        gt_sql_box = gr.Code(label="Ground Truth SQL", language="sql", lines=3)
    gr.Markdown("### Data Comparison")
    with gr.Row():
        generated_table = gr.Dataframe(label="Generated Result", type="pandas")
        gt_table = gr.Dataframe(label="Ground Truth Result", type="pandas")
    with gr.Accordion("See Full Source Table", open=False):
        full_table = gr.Dataframe(label="Full Table Content", type="pandas")
def update_ui_on_version_or_id(version, q_id):
"""Updates range text and pre-loads question data when version or ID changes."""
dataset = DATASETS[version]
range_text = get_range_display(version)
try:
idx = int(q_id) if (q_id and str(q_id).isdigit()) else 1
q_data = dataset["questions"][idx]
table_id = q_data["table_id"]
raw_table = dataset["tables"].get(table_id, {})
df = pd.DataFrame(
raw_table.get("rows", []), columns=raw_table.get("header", [])
)
return (
q_data["question"],
q_data["sql"],
df,
gr.update(label="Generated SQL", value=""),
range_text,
)
except Exception:
return (
"ID not found in this version",
"",
pd.DataFrame(),
gr.update(label="Generated SQL"),
range_text,
)
def handle_inference(version, model, few_shot, q_id):
q_text, gen_sql, gt_sql, gen_df, gt_df, full_df, is_match = run_inference(
version, model, q_id, few_shot
)
status = (
"✅ Generated SQL (MATCH SUCCESS)"
if is_match
else "❌ Generated SQL (MATCH FAILED)"
)
return gr.update(label=status, value=gen_sql), gen_df, gt_df
# Event Listeners
version_dropdown.change(
update_ui_on_version_or_id,
inputs=[version_dropdown, question_input],
outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
)
question_input.change(
update_ui_on_version_or_id,
inputs=[version_dropdown, question_input],
outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
)
run_button.click(
handle_inference,
inputs=[version_dropdown, model_dropdown, few_shot_dropdown, question_input],
outputs=[generated_sql_box, generated_table, gt_table],
)
app.load(
update_ui_on_version_or_id,
inputs=[version_dropdown, question_input],
outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
)
app.launch()