# NOTE: header captured from the web file viewer; not part of the program.
# Spaces status: Sleeping. File size: 3,676 bytes. Revision: 43ebb71.
import gradio as gr
# Showcase scenarios. Each row holds, in order: a short title, the user's
# question, the schema given as context, the text shown for the base model,
# and the text shown for the fine-tuned model.
_EXAMPLE_KEYS = ("description", "question", "context", "base_model", "my_model")

_EXAMPLE_ROWS = (
    (
        "Simple Filtering",
        "Show me the names of staff members hired after 2020.",
        "CREATE TABLE staff (staff_id INTEGER, name VARCHAR, join_date DATE, position VARCHAR)",
        "SELECT name FROM employees WHERE hire_date > '2020-01-01' \n"
        "-- Error: Hallucinated table 'employees' and column 'hire_date'",
        "SELECT name FROM staff WHERE join_date > '2020-12-31'",
    ),
    (
        "Column Ambiguity",
        "What is the budget of the Marketing department?",
        "CREATE TABLE department (dept_id INTEGER, dept_name VARCHAR, budget_in_billions INTEGER)",
        "SELECT budget FROM department WHERE name = 'Marketing' \n"
        "-- Error: Column 'budget' does not exist. It is 'budget_in_billions'.",
        "SELECT budget_in_billions FROM department WHERE dept_name = 'Marketing'",
    ),
    (
        "Counting & Logic",
        "How many departments have more than 10 employees?",
        "CREATE TABLE department (name VARCHAR, num_employees INTEGER, ranking INTEGER)",
        "SELECT COUNT(*) FROM department WHERE employees > 10 \n"
        "-- Error: Wrong column name 'employees'.",
        "SELECT COUNT(*) FROM department WHERE num_employees > 10",
    ),
)

# Materialise the rows as a list of dicts keyed by _EXAMPLE_KEYS (same key
# order for every entry).
examples = [dict(zip(_EXAMPLE_KEYS, row)) for row in _EXAMPLE_ROWS]
# --- THE APP LOGIC ---
def get_example(index, catalog=None):
    """Return the display values for the example at ``index``, wrapping around.

    Args:
        index: Integer position of the requested example. It is reduced
            modulo the number of examples, so values past the end (and
            negative values) wrap around instead of raising ``IndexError``.
        catalog: Optional sequence of example dicts, each providing the keys
            ``description``, ``question``, ``context``, ``base_model`` and
            ``my_model``. When omitted, the module-level ``examples`` list
            is used.

    Returns:
        A 6-tuple ``(safe_index, label, question, context, base_model,
        my_model)`` where ``safe_index`` is the wrapped index and ``label``
        is ``"Example <n>: <description>"`` with a 1-based ``n``.

    Raises:
        ValueError: If there are no examples to choose from.
    """
    if catalog is None:
        catalog = examples
    if not catalog:
        # Without this guard the modulo below fails with an opaque
        # ZeroDivisionError.
        raise ValueError("no examples available")

    # Keep index within bounds
    safe_index = index % len(catalog)
    ex = catalog[safe_index]
    return (
        safe_index,
        f"Example {safe_index + 1}: {ex['description']}",  # Label
        ex["question"],
        ex["context"],
        ex["base_model"],
        ex["my_model"],
    )
# --- THE UI DESIGN ---
# Builds the page: a header, a "next" button with a status label, and two
# columns showing the input (question + schema) and the two model outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Emoji below were restored from mis-decoded UTF-8 in the source text.
    gr.Markdown("# 🦙 Llama-3 Text-to-SQL: Fine-Tuning Showcase")
    gr.Markdown(
        "\n"
        "### The Problem\n"
        'Base models (like Llama-3-8B) often "hallucinate" database column names '
        "or invent tables that don't exist.\n"
        "### The Solution\n"
        "I fine-tuned Llama-3-8B using **Unsloth & QLoRA** on the Spider dataset "
        "to learn **Schema Linking**.\n"
    )

    # Initial widget values come from the same function the button uses, so
    # the first render and later clicks cannot drift apart.
    (
        _first_index,
        _first_label,
        _first_question,
        _first_context,
        _first_base,
        _first_mine,
    ) = get_example(0)

    # State variable to track which example we are on
    index_state = gr.State(value=_first_index)

    with gr.Row():
        # NOTE(review): the button's original leading emoji could not be
        # recovered from the garbled source; confirm the intended glyph.
        btn_next = gr.Button("Show Next Example", variant="primary", scale=0)
        lbl_status = gr.Label(value=_first_label, label="Current Scenario", scale=1)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📥 Input")
            out_q = gr.Textbox(label="User Question", value=_first_question, interactive=False)
            out_c = gr.Code(
                label="Database Schema (Context)",
                value=_first_context,
                language="sql",
                interactive=False,
            )
        with gr.Column():
            gr.Markdown("### 🤖 Model Comparison")
            out_base = gr.Code(
                label="❌ Base Llama-3 (Before)",
                value=_first_base,
                language="sql",
                interactive=False,
            )
            # This label was split across two lines in the source, leaving an
            # unterminated string literal; it is rejoined here.
            out_mine = gr.Code(
                label="✅ My Fine-Tuned Model (After)",
                value=_first_mine,
                language="sql",
                interactive=False,
            )

    # Logic to increment index
    def increment_index(i):
        """Return the index that follows ``i``."""
        return i + 1

    def show_next(i):
        """Advance from index ``i`` and return the values for the next example.

        The returned tuple starts with the new (wrapped) index, which is
        written back to ``index_state``, followed by the five display values.
        """
        return get_example(increment_index(i))

    # Click Event
    # A single listener both advances the index and refreshes the widgets.
    # Two separate listeners writing ``index_state`` from the same click
    # would make the stored index depend on which handler finished last.
    btn_next.click(
        fn=show_next,
        inputs=[index_state],
        outputs=[index_state, lbl_status, out_q, out_c, out_base, out_mine],
    )
# Launch
# Starts the Gradio app defined above. The stray trailing "|" that followed
# this call in the source was not valid Python and has been removed.
demo.launch()