# Karthix1's picture
# initial commit
# 43ebb71 verified
import gradio as gr
# Hand-written showcase data. Each entry pairs a natural-language question and
# the schema it refers to ("context") with two SQL answers: the base model's
# output (annotated with what is wrong) and the fine-tuned model's output.
examples = [
    # 1) Table and column names differ from the "obvious" guess.
    {
        "description": "Simple Filtering",
        "question": "Show me the names of staff members hired after 2020.",
        "context": "CREATE TABLE staff (staff_id INTEGER, name VARCHAR, join_date DATE, position VARCHAR)",
        "base_model": "SELECT name FROM employees WHERE hire_date > '2020-01-01' \n-- Error: Hallucinated table 'employees' and column 'hire_date'",
        "my_model": "SELECT name FROM staff WHERE join_date > '2020-12-31'",
    },
    # 2) The schema uses longer, more specific column names.
    {
        "description": "Column Ambiguity",
        "question": "What is the budget of the Marketing department?",
        "context": "CREATE TABLE department (dept_id INTEGER, dept_name VARCHAR, budget_in_billions INTEGER)",
        "base_model": "SELECT budget FROM department WHERE name = 'Marketing' \n-- Error: Column 'budget' does not exist. It is 'budget_in_billions'.",
        "my_model": "SELECT budget_in_billions FROM department WHERE dept_name = 'Marketing'",
    },
    # 3) Aggregation with a filter on a column the question only hints at.
    {
        "description": "Counting & Logic",
        "question": "How many departments have more than 10 employees?",
        "context": "CREATE TABLE department (name VARCHAR, num_employees INTEGER, ranking INTEGER)",
        "base_model": "SELECT COUNT(*) FROM department WHERE employees > 10 \n-- Error: Wrong column name 'employees'.",
        "my_model": "SELECT COUNT(*) FROM department WHERE num_employees > 10",
    },
]
# --- THE APP LOGIC ---
def get_example(index):
    """Look up the example for ``index`` and return the values the UI shows.

    ``index`` is wrapped modulo the number of entries in ``examples``, so any
    integer (including one past the end) maps to a valid entry.

    Returns:
        A 6-tuple of: the wrapped index, a label of the form
        ``"Example <n>: <description>"`` (``n`` is 1-based), and the entry's
        ``question``, ``context``, ``base_model`` and ``my_model`` strings.
    """
    # Wrap around instead of raising when the caller runs past the last entry.
    position = index % len(examples)
    entry = examples[position]
    label = f"Example {position + 1}: {entry['description']}"
    return (
        position,
        label,
        *(entry[key] for key in ("question", "context", "base_model", "my_model")),
    )
# --- THE UI DESIGN ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦙 Llama-3 Text-to-SQL: Fine-Tuning Showcase")
    gr.Markdown(
        "\n"
        "### The Problem\n"
        "Base models (like Llama-3-8B) often \"hallucinate\" database column names "
        "or invent tables that don't exist.\n"
        "### The Solution\n"
        "I fine-tuned Llama-3-8B using **Unsloth & QLoRA** on the Spider dataset "
        "to learn **Schema Linking**.\n"
    )

    # Seed every widget from the same helper the button uses, so the first
    # screen cannot drift out of sync with the `examples` data.
    (
        start_index,
        start_label,
        start_question,
        start_context,
        start_base,
        start_mine,
    ) = get_example(0)

    # State variable to track which example we are on
    index_state = gr.State(value=start_index)

    with gr.Row():
        btn_next = gr.Button("👉 Show Next Example", variant="primary", scale=0)
        lbl_status = gr.Label(value=start_label, label="Current Scenario", scale=1)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📥 Input")
            out_q = gr.Textbox(label="User Question", value=start_question, interactive=False)
            out_c = gr.Code(
                label="Database Schema (Context)",
                value=start_context,
                language="sql",
                interactive=False,
            )
        with gr.Column():
            gr.Markdown("### 🤖 Model Comparison")
            out_base = gr.Code(
                label="❌ Base Llama-3 (Before)",
                value=start_base,
                language="sql",
                interactive=False,
            )
            out_mine = gr.Code(
                label="✅ My Fine-Tuned Model (After)",
                value=start_mine,
                language="sql",
                interactive=False,
            )

    # Logic to increment index
    def increment_index(i):
        """Return the index that follows ``i``."""
        return i + 1

    def show_next_example(current_index):
        """Advance one step from ``current_index`` and return the new UI values.

        The returned tuple comes straight from ``get_example``; its first item
        is the wrapped index, which becomes the new ``index_state``.
        """
        return get_example(increment_index(current_index))

    # Click Event
    # A single listener both advances the index and renders the result. Two
    # separate listeners would each receive the same pre-click state value and
    # both write `index_state`, so the display would lag one click behind and
    # the stored index would depend on which listener finished last.
    btn_next.click(
        fn=show_next_example,
        inputs=[index_state],
        outputs=[index_state, lbl_status, out_q, out_c, out_base, out_mine],
    )
# Launch the Gradio app held in `demo`.
demo.launch()