# Hugging Face Space: Text-to-SQL interactive debugger (Gradio app).
import os
import sqlite3
import gradio as gr
import pandas as pd
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# Import local prompt builders
from build_prompt import (
build_prompt_0shot,
build_prompt_1shot,
build_prompt_5shot,
)
# Import both versions of the dataset
from dataset_generator import (
questions_llmsql_1,
questions_llmsql_2,
tables_llmsql_1,
tables_llmsql_2,
)
from dataset_generator import split_info as split_info_v1
from dataset_generator import split_info as split_info_v2
from evaluate import evaluate_sample
# Global variables for caching
# Module-level cache, populated lazily by load_model_if_needed() and
# initialize_data(): the active HF model/tokenizer and the model id they were
# loaded for, plus the open SQLite connection and the path of the database
# file it points at.
model = None
tokenizer = None
current_model_id = None
conn = None
current_db_path = None
# Mapping for dynamic access
# Registry of the benchmark versions selectable in the UI. Each entry bundles
# the in-memory question/table data, split metadata for the ID-range display,
# the HF dataset repo the SQLite database is downloaded from, and the local
# folder it is cached in.
# NOTE(review): split_info_v1 and split_info_v2 are both imported from
# dataset_generator under the single name `split_info` (see the imports at the
# top of the file), so the two versions currently share one split_info
# object — confirm this aliasing is intended.
DATASETS = {
    "LLMSQL 1.0": {
        "questions": questions_llmsql_1,
        "tables": tables_llmsql_1,
        "split_info": split_info_v1,
        "repo": "llmsql-bench/llmsql-benchmark",
        "folder": "llmsql_1.0",
    },
    "LLMSQL 2.0": {
        "questions": questions_llmsql_2,
        "tables": tables_llmsql_2,
        "split_info": split_info_v2,
        "repo": "llmsql-bench/llmsql-2.0",
        "folder": "llmsql_2.0",
    },
}
# =====================
# Initialization Logic
# =====================
def initialize_data(version):
    """Download (if needed) and connect to the version-specific SQLite database.

    Args:
        version: Key into DATASETS ("LLMSQL 1.0" or "LLMSQL 2.0").

    Side effects:
        Updates the module-level `conn` / `current_db_path` cache; closes the
        previous connection when switching database files.
    """
    global conn, current_db_path
    config = DATASETS[version]
    app_dir = os.getcwd()
    ver_dir = os.path.join(app_dir, config["folder"])
    os.makedirs(ver_dir, exist_ok=True)
    db_file = os.path.join(ver_dir, "sqlite_tables.db")
    # Reconnect if the version changed OR the connection was never opened.
    # (Bug fix: the original condition checked only the path, so a None `conn`
    # with an unchanged path was never actually connected.)
    if current_db_path != db_file or conn is None:
        if not os.path.exists(db_file):
            print(f"Downloading database for {version}...")
            hf_hub_download(
                repo_id=config["repo"],
                repo_type="dataset",
                filename="sqlite_tables.db",
                local_dir=ver_dir,
            )
        if conn:
            conn.close()
        # check_same_thread=False: Gradio may invoke handlers on worker threads.
        conn = sqlite3.connect(db_file, check_same_thread=False)
        current_db_path = db_file
def load_model_if_needed(model_id):
    """Load `model_id` into the module-level cache, skipping when already active."""
    global model, tokenizer, current_model_id
    # Cache hit: the requested model is already loaded.
    cache_hit = model is not None and current_model_id == model_id
    if cache_hit:
        return
    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    load_kwargs = {
        "torch_dtype": "auto",
        "device_map": "cuda",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    current_model_id = model_id
    print(f"Model {model_id} loaded successfully.")
# Dispatch table mapping the UI's few-shot choice ("0"/"1"/"5") to the
# corresponding prompt builder imported from build_prompt.
few_shot_selection = {
    "0": build_prompt_0shot,
    "1": build_prompt_1shot,
    "5": build_prompt_5shot,
}
# =====================
# Main Logic
# =====================
@spaces.GPU
def run_inference(version, model_id, question_idx, few_shots):
    """Generate SQL for one benchmark question and evaluate it against gold.

    Args:
        version: DATASETS key selecting the benchmark version.
        model_id: HF model identifier to load and generate with.
        question_idx: Question ID as entered in the UI (string or int).
        few_shots: "0" / "1" / "5"; selects the prompt builder.

    Returns:
        7-tuple of (question text, generated SQL, ground-truth SQL,
        prediction results, gold results, full source-table DataFrame,
        match flag). On invalid input the first element is an error string
        and the trailing elements are empty/None/False.
    """
    initialize_data(version)
    load_model_if_needed(model_id)
    dataset = DATASETS[version]
    qs = dataset["questions"]
    ts = dataset["tables"]
    try:
        idx = int(question_idx)
        q_data = qs[idx]
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
    # ValueError/TypeError: non-numeric input; IndexError/KeyError: unknown ID.
    except (ValueError, TypeError, IndexError, KeyError):
        return "Invalid ID", "", "", None, None, None, False
    question = q_data["question"]
    ground_truth_sql = q_data["sql"]
    table = ts.get(q_data["table_id"])
    if not table:
        return "Table data missing", "", "", None, None, None, False
    example_row = table["rows"][0] if table["rows"] else []
    raw_prompt = few_shot_selection[few_shots](
        question, table["header"], table["types"], example_row
    )
    messages = [{"role": "user", "content": raw_prompt}]
    text_input = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text_input], return_tensors="pt").to("cuda")
    # Greedy decoding. `temperature` is ignored (and warned about by
    # transformers) when do_sample=False, so it is not passed.
    generated_ids = model.generate(**model_inputs, max_new_tokens=512, do_sample=False)
    # Strip the prompt tokens; keep only the newly generated completion.
    new_tokens = generated_ids[0][len(model_inputs.input_ids[0]) :]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
    is_match, mismatch_info, _ = evaluate_sample(
        item={"question_id": idx, "completion": completion},
        questions=qs,
        conn=conn,
    )
    return (
        question,
        completion,
        ground_truth_sql,
        mismatch_info["prediction_results"],
        mismatch_info["gold_results"],
        pd.DataFrame(table["rows"], columns=table["header"]),
        bool(is_match),
    )
# =====================
# UI Helpers
# =====================
def get_range_display(version):
    """Render the split ID ranges for `version`, one markdown line per split."""
    split_meta = DATASETS[version]["split_info"]
    return "\n".join(
        f"**{name.capitalize()}**: IDs {meta.get('first')} to {meta.get('last')} (Total: {meta.get('count')})"
        for name, meta in split_meta.items()
    )
# ---- Gradio UI definition ----
with gr.Blocks(title="Text-to-SQL Debugger", theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🔍 Text-to-SQL Interactive Debugger")
    # Top row: dataset version, model, and few-shot configuration.
    with gr.Row():
        version_dropdown = gr.Dropdown(
            choices=["LLMSQL 1.0", "LLMSQL 2.0"],
            value="LLMSQL 2.0",
            label="Dataset Version",
            scale=1,
        )
        model_dropdown = gr.Dropdown(
            choices=["Qwen/Qwen2.5-1.5B-Instruct", "openai/gpt-oss-20b"],
            value="Qwen/Qwen2.5-1.5B-Instruct",
            label="1. Select Model",
            scale=1,
        )
        few_shot_dropdown = gr.Dropdown(
            choices=["0", "1", "5"],
            value="5",
            label="2. Few-shot Examples",
            scale=1,
        )
    # Second row: question ID input beside the valid-ID-range helper text.
    with gr.Row():
        question_input = gr.Textbox(
            label="3. Enter Question ID",
            value="1",
            lines=2,
            placeholder="e.g. 15001",
            scale=2,
            min_width=200,
        )
        with gr.Column(scale=1):
            range_md = gr.Markdown(
                get_range_display("LLMSQL 2.0"),
                line_breaks=True,
                padding=True,
            )
    run_button = gr.Button("Run Inference", variant="primary")
    question_box = gr.Textbox(
        label="Natural Language Question", lines=2, interactive=False
    )
    # Generated vs. ground-truth SQL, side by side.
    with gr.Row():
        generated_sql_box = gr.Code(label="Generated SQL", language="sql", lines=3)
        gt_sql_box = gr.Code(label="Ground Truth SQL", language="sql", lines=3)
    gr.Markdown("### Data Comparison")
    # Query execution results, side by side.
    with gr.Row():
        generated_table = gr.Dataframe(label="Generated Result", type="pandas")
        gt_table = gr.Dataframe(label="Ground Truth Result", type="pandas")
    with gr.Accordion("See Full Source Table", open=False):
        full_table = gr.Dataframe(label="Full Table Content", type="pandas")
def update_ui_on_version_or_id(version, q_id):
"""Updates range text and pre-loads question data when version or ID changes."""
dataset = DATASETS[version]
range_text = get_range_display(version)
try:
idx = int(q_id) if (q_id and str(q_id).isdigit()) else 1
q_data = dataset["questions"][idx]
table_id = q_data["table_id"]
raw_table = dataset["tables"].get(table_id, {})
df = pd.DataFrame(
raw_table.get("rows", []), columns=raw_table.get("header", [])
)
return (
q_data["question"],
q_data["sql"],
df,
gr.update(label="Generated SQL", value=""),
range_text,
)
except Exception:
return (
"ID not found in this version",
"",
pd.DataFrame(),
gr.update(label="Generated SQL"),
range_text,
)
def handle_inference(version, model, few_shot, q_id):
q_text, gen_sql, gt_sql, gen_df, gt_df, full_df, is_match = run_inference(
version, model, q_id, few_shot
)
status = (
"✅ Generated SQL (MATCH SUCCESS)"
if is_match
else "❌ Generated SQL (MATCH FAILED)"
)
return gr.update(label=status, value=gen_sql), gen_df, gt_df
# Event Listeners
version_dropdown.change(
update_ui_on_version_or_id,
inputs=[version_dropdown, question_input],
outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
)
question_input.change(
update_ui_on_version_or_id,
inputs=[version_dropdown, question_input],
outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
)
run_button.click(
handle_inference,
inputs=[version_dropdown, model_dropdown, few_shot_dropdown, question_input],
outputs=[generated_sql_box, generated_table, gt_table],
)
app.load(
update_ui_on_version_or_id,
inputs=[version_dropdown, question_input],
outputs=[question_box, gt_sql_box, full_table, generated_sql_box, range_md],
)
app.launch()