File size: 4,347 Bytes
1b640b4
f76a063
 
4a8cdc3
f76a063
1b640b4
f76a063
 
fcd041f
f76a063
fcd041f
f76a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcd041f
f76a063
 
 
 
fcd041f
f76a063
 
 
 
 
 
 
 
 
1b640b4
f76a063
6ea8ee7
f76a063
 
1b640b4
f76a063
 
fcd041f
f76a063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import re
import sqlite3

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Browser-tab title. Older Streamlit releases require set_page_config() to be
# the first st.* call in the script, so it stays at the top.
st.set_page_config(page_title="NL to SQL Agent")


@st.cache_resource
def init_db():
    """Open ``customer_product.db``, creating and seeding the sample tables if needed.

    ``st.cache_resource`` makes every rerun and session share this single
    connection, which is why SQLite's same-thread check is disabled.

    Returns:
        sqlite3.Connection: the committed, ready-to-query connection.
    """
    db = sqlite3.connect("customer_product.db", check_same_thread=False)
    cur = db.cursor()

    # Schema: created only when missing, so an existing database file is reused.
    cur.execute("""CREATE TABLE IF NOT EXISTS customers(
                    id INTEGER PRIMARY KEY,
                    name TEXT, email TEXT, city TEXT)""")
    cur.execute("""CREATE TABLE IF NOT EXISTS products(
                    id INTEGER PRIMARY KEY,
                    name TEXT, price REAL)""")

    # Sample rows per table; a table is filled only while it is still empty.
    seeds = (
        (
            "customers",
            "INSERT INTO customers VALUES (?,?,?,?)",
            [
                (1, "Alice", "alice@example.com", "New York"),
                (2, "Bob", "bob@example.com", "Los Angeles"),
                (3, "Carol", "carol@example.com", "Chicago"),
            ],
        ),
        (
            "products",
            "INSERT INTO products VALUES (?,?,?)",
            [
                (1, "Widget", 9.99),
                (2, "Gizmo", 14.99),
                (3, "Doodad", 7.49),
            ],
        ),
    )
    for table, insert_sql, rows in seeds:
        (row_count,) = cur.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        if not row_count:
            cur.executemany(insert_sql, rows)

    db.commit()
    return db

conn = init_db()
cursor = conn.cursor()  # single cursor reused by the query loop below


@st.cache_resource
def load_generator():
    """Build the Hugging Face text-generation pipeline once per server process.

    Returns:
        transformers.Pipeline: a ``text-generation`` pipeline placed on the first
        CUDA device when one is available, otherwise on the CPU.
    """
    import torch  # AutoModelForCausalLM is the PyTorch auto-class, so torch is present

    model_id = "microsoft/Phi-4-mini-flash-reasoning"  # example; use any available LLM
    # NOTE(review): some hub models need extra from_pretrained() options
    # (e.g. trust_remote_code); confirm against the chosen model's card.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    # Without device_map, from_pretrained() loads the weights on the CPU, so
    # model.device is always "cpu" at this point. Ask torch whether a GPU exists
    # and let the pipeline move the model there.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)


generator = load_generator()

# Schema description shown to the model; keep in sync with the tables in init_db().
SCHEMA_PROMPT = (
    "The database schema is:\n"
    "Customers(id, name, email, city)\n"
    "Products(id, name, price)\n"
)
MAX_ATTEMPTS = 3  # total generate-and-execute attempts per question

# A Markdown code fence, optionally tagged sql/sqlite, up to the closing fence
# or the end of the text (generation may be cut off by max_new_tokens).
_FENCE_RE = re.compile(r"```(?:(?:sqlite|sql)\b)?(.*?)(?:```|\Z)", re.IGNORECASE | re.DOTALL)


def build_prompt(question, previous_sql=None, error_msg=None):
    """Return the prompt for the next generation attempt.

    Args:
        question: the user's plain-English request.
        previous_sql: SQL produced by the previous attempt, if any.
        error_msg: why the previous attempt failed; ``None`` selects the
            first-attempt prompt.

    Returns:
        str: prompt text ending in ``"SQL Query:"``.
    """
    intro = (
        "You are an assistant that converts English questions into SQL queries. "
        + SCHEMA_PROMPT
    )
    if error_msg is None:
        return (
            intro
            + f"Convert the request into an SQL query (SQLite syntax):\n\"\"\"\n{question}\n\"\"\"\n"
            "SQL Query:"
        )
    # Retries repeat the schema and the original request: without them the
    # model has nothing to correct the failing query against.
    return (
        intro
        + f"The request was:\n\"\"\"\n{question}\n\"\"\"\n"
        f"The previous SQL query was:\n{previous_sql}\n"
        f"It failed with error: {error_msg}\n"
        "Please provide a corrected SQL query (SQLite syntax).\n"
        "SQL Query:"
    )


def extract_sql(text):
    """Return the first SQL statement found in raw model output.

    A Markdown code fence, if present, is unwrapped first. The text is then cut
    right after the first ``;`` that ends a complete statement according to
    ``sqlite3.complete_statement``. As a result, semicolons inside string
    literals or comments are not mistaken for terminators, and trailing prose or
    extra statements are dropped. Text with no terminating ``;`` is returned
    stripped but otherwise unchanged.
    """
    text = text.strip()
    fenced = _FENCE_RE.search(text)
    if fenced:
        text = fenced.group(1).strip()
    for end, char in enumerate(text, start=1):
        if char == ";" and sqlite3.complete_statement(text[:end]):
            return text[:end]
    return text


def run_query(sql):
    """Execute one statement on the shared connection and return its rows.

    Statements that produce a result set (SELECT, WITH ... SELECT, PRAGMA, ...)
    return all fetched rows. Any other statement returns ``[]``. Pending changes
    are committed in both cases.

    Raises:
        sqlite3.Error: if SQLite rejects or fails to run the statement; any
            transaction it opened is rolled back first.
    """
    # NOTE(review): the SQL is model-generated from untrusted user text, and
    # write statements (UPDATE/DELETE/DROP ...) are committed here. Consider a
    # read-only connection or conn.set_authorizer() if that is not intended.
    try:
        cursor.execute(sql)
        # description is non-None exactly when the statement yields columns. This
        # also covers CTEs and comment-prefixed queries that a
        # startswith("select") check misses.
        rows = cursor.fetchall() if cursor.description is not None else []
        conn.commit()
    except sqlite3.Error:
        # Don't leave a half-open implicit transaction on the shared connection.
        conn.rollback()
        raise
    return rows


def main():
    """Render the page and, on "Run Query", generate, execute and display SQL."""
    st.title("Natural Language to SQL Query App")
    st.write("Enter a request in plain English; the app will generate and run SQL on a sample database.")

    # User input
    question = st.text_area("Enter your query:", height=100)

    if not st.button("Run Query"):
        return
    if not question.strip():
        st.error("Please enter a query.")
        return

    # Agentic loop: generate SQL, run it, and feed any error back to the model.
    sql_query = None
    error_msg = None
    result_rows = None
    for _ in range(MAX_ATTEMPTS):
        prompt = build_prompt(question, sql_query, error_msg)
        output = generator(prompt, max_new_tokens=100, return_full_text=False)
        sql_query = extract_sql(output[0]["generated_text"])
        if not sql_query:
            # Executing an empty string is not a real success; count it as a failure.
            error_msg = "the model returned no SQL"
            continue
        try:
            result_rows = run_query(sql_query)
        except sqlite3.Error as e:
            error_msg = str(e)
        else:
            break
    else:
        st.error(f"SQL execution failed after {MAX_ATTEMPTS} attempts: {error_msg}")

    # Display final SQL and results
    if sql_query:
        st.subheader("Generated SQL Query")
        st.code(sql_query)
    if result_rows is not None:
        st.subheader("Query Results")
        if result_rows:
            st.table(result_rows)
        else:
            st.write("(No results returned)")


# Streamlit executes the app script as ``__main__``. The guard keeps the helpers
# above importable (e.g. by tests) without rendering the page.
if __name__ == "__main__":
    main()