File size: 3,676 Bytes
43ebb71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr

# Curated demo scenarios. Each record has the keys:
#   description - short scenario title shown in the status label
#   question    - natural-language question posed to the model
#   context     - CREATE TABLE schema given to the model as context
#   base_model  - SQL from the base model, plus an annotated "-- Error:" note
#   my_model    - SQL from the fine-tuned model
examples = [
    {
        "description": "Simple Filtering",
        "question": "Show me the names of staff members hired after 2020.",
        "context": "CREATE TABLE staff (staff_id INTEGER, name VARCHAR, join_date DATE, position VARCHAR)",
        # "after 2020" means from 2021-01-01 onward, so '> 2020-01-01' is also wrong.
        "base_model": "SELECT name FROM employees WHERE hire_date > '2020-01-01' \n-- Error: Hallucinated table 'employees' and column 'hire_date' (and the date bound wrongly includes most of 2020)",
        "my_model": "SELECT name FROM staff WHERE join_date > '2020-12-31'",
    },
    {
        "description": "Column Ambiguity",
        "question": "What is the budget of the Marketing department?",
        "context": "CREATE TABLE department (dept_id INTEGER, dept_name VARCHAR, budget_in_billions INTEGER)",
        # The schema has neither 'budget' nor 'name'; both columns were invented.
        "base_model": "SELECT budget FROM department WHERE name = 'Marketing' \n-- Error: Columns 'budget' and 'name' do not exist. They are 'budget_in_billions' and 'dept_name'.",
        "my_model": "SELECT budget_in_billions FROM department WHERE dept_name = 'Marketing'",
    },
    {
        "description": "Counting & Logic",
        "question": "How many departments have more than 10 employees?",
        "context": "CREATE TABLE department (name VARCHAR, num_employees INTEGER, ranking INTEGER)",
        "base_model": "SELECT COUNT(*) FROM department WHERE employees > 10 \n-- Error: Wrong column name 'employees'.",
        "my_model": "SELECT COUNT(*) FROM department WHERE num_employees > 10",
    },
]

# --- THE APP LOGIC ---
def get_example(index):
    """Return the widget values for the example at ``index``.

    ``index`` is reduced modulo ``len(examples)``, so any integer (including
    one past the end, or a negative one) selects a valid example.

    Returns:
        A 6-tuple ``(wrapped_index, status_label, question, context,
        base_model_sql, fine_tuned_sql)``, in the same order as the click
        handler's ``outputs`` list.
    """
    position = index % len(examples)
    record = examples[position]
    status_text = f"Example {position + 1}: {record['description']}"
    shown_fields = tuple(
        record[key] for key in ("question", "context", "base_model", "my_model")
    )
    return (position, status_text) + shown_fields

# --- THE UI DESIGN ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦙 Llama-3 Text-to-SQL: Fine-Tuning Showcase")
    gr.Markdown(
        """

        ### The Problem

        Base models (like Llama-3-8B) often "hallucinate" database column names or invent tables that don't exist.

        

        ### The Solution

        I fine-tuned Llama-3-8B using **Unsloth & QLoRA** on the Spider dataset to learn **Schema Linking**.

        """
    )

    # Initial widget values come from get_example(0), so the status label and
    # the displayed example can never drift apart from the data.
    (
        start_index,
        start_label,
        start_question,
        start_context,
        start_base_sql,
        start_mine_sql,
    ) = get_example(0)

    # State variable holding the index of the example currently on screen.
    index_state = gr.State(value=start_index)

    with gr.Row():
        btn_next = gr.Button("👉 Show Next Example", variant="primary", scale=0)
        lbl_status = gr.Label(value=start_label, label="Current Scenario", scale=1)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📥 Input")
            out_q = gr.Textbox(label="User Question", value=start_question, interactive=False)
            out_c = gr.Code(label="Database Schema (Context)", value=start_context, language="sql", interactive=False)

        with gr.Column():
            gr.Markdown("### 🤖 Model Comparison")
            out_base = gr.Code(label="❌ Base Llama-3 (Before)", value=start_base_sql, language="sql", interactive=False)
            out_mine = gr.Code(label="✅ My Fine-Tuned Model (After)", value=start_mine_sql, language="sql", interactive=False)

    def increment_index(i):
        """Return the index after ``i``; get_example wraps it into range."""
        return i + 1

    def show_next_example(current_index):
        """Advance past ``current_index`` and return the next example's values.

        Advancing and rendering happen in one handler. If two separate click
        handlers read and write ``index_state``, they race, and the render
        would use the *current* index, so the first click would re-show the
        example already on screen.
        """
        return get_example(increment_index(current_index))

    # Click Event: one handler updates the stored index and every display widget.
    btn_next.click(
        fn=show_next_example,
        inputs=[index_state],
        outputs=[index_state, lbl_status, out_q, out_c, out_base, out_mine],
    )

# Launch
demo.launch()