# text-to-sql / app.py
# Gaurav-2273's picture
# Update app.py
# f76a063 verified
import sqlite3
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Configure the Streamlit page; only the page title is set here.
st.set_page_config(page_title="NL to SQL Agent")
# Initialize or retrieve cached DB connection
@st.cache_resource
def init_db():
    """Open the sample SQLite database, creating and seeding it if needed.

    Connects to ``customer_product.db`` with ``check_same_thread=False``,
    creates the ``customers`` and ``products`` tables when they do not exist,
    and inserts three sample rows into each table that is currently empty.
    All changes are committed before returning.

    Returns:
        The open ``sqlite3.Connection``.
    """
    connection = sqlite3.connect("customer_product.db", check_same_thread=False)
    cur = connection.cursor()

    # Schema first: both tables are created before any row counting happens.
    table_definitions = (
        "CREATE TABLE IF NOT EXISTS customers(\n"
        "id INTEGER PRIMARY KEY,\n"
        "name TEXT, email TEXT, city TEXT)",
        "CREATE TABLE IF NOT EXISTS products(\n"
        "id INTEGER PRIMARY KEY,\n"
        "name TEXT, price REAL)",
    )
    for ddl in table_definitions:
        cur.execute(ddl)

    # One entry per table: (row-count query, insert statement, sample rows).
    seed_plan = (
        (
            "SELECT COUNT(*) FROM customers",
            "INSERT INTO customers VALUES (?,?,?,?)",
            [
                (1, "Alice", "alice@example.com", "New York"),
                (2, "Bob", "bob@example.com", "Los Angeles"),
                (3, "Carol", "carol@example.com", "Chicago"),
            ],
        ),
        (
            "SELECT COUNT(*) FROM products",
            "INSERT INTO products VALUES (?,?,?)",
            [
                (1, "Widget", 9.99),
                (2, "Gizmo", 14.99),
                (3, "Doodad", 7.49),
            ],
        ),
    )
    for count_sql, insert_sql, sample_rows in seed_plan:
        cur.execute(count_sql)
        (row_count,) = cur.fetchone()
        # Only seed a table that has no rows at all.
        if row_count == 0:
            cur.executemany(insert_sql, sample_rows)

    connection.commit()
    return connection
# Obtain the database connection from init_db() and create the module-level
# cursor on which the generated SQL is executed further down in this script.
conn = init_db()
cursor = conn.cursor()
# Load the Hugging Face model and tokenizer once (cached)
@st.cache_resource
def load_generator():
    """Build the text-generation pipeline used to turn questions into SQL.

    Loads the tokenizer and causal language model named by ``model_id`` and
    wraps them in a ``transformers`` ``"text-generation"`` pipeline. The
    pipeline is placed on CUDA device 0 when torch reports that CUDA is
    available, otherwise on the CPU (``device=-1``).

    Returns:
        The ``text-generation`` pipeline object.
    """
    # Local import: the rest of this file does not import torch directly.
    import torch

    model_id = "microsoft/Phi-4-mini-flash-reasoning"  # example; use any available LLM
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # from_pretrained() is called without a device_map, so the model has just
    # been loaded on the CPU; checking model.device here would therefore never
    # pick the GPU. Ask torch whether CUDA is usable and let the pipeline move
    # the model to the chosen device.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
# Obtain the text-generation pipeline from load_generator().
generator = load_generator()
# Page heading and a one-line description of what the app does.
st.title("Natural Language to SQL Query App")
st.write("Enter a request in plain English; the app will generate and run SQL on a sample database.")
# Multi-line text box for the user's request; its current text is `question`.
question = st.text_area("Enter your query:", height=100)
# Maximum number of generate-and-execute attempts per button press.
_MAX_ATTEMPTS = 3

# Instructions and schema shared by the first prompt and every retry prompt.
_PROMPT_HEADER = (
    "You are an assistant that converts English questions into SQL queries. "
    "The database schema is:\n"
    "Customers(id, name, email, city)\n"
    "Products(id, name, price)\n"
)


def _build_prompt(request, previous_sql=None, previous_error=None):
    """Return the LLM prompt for *request*.

    The first prompt (``previous_error is None``) contains the instructions,
    the schema and the request. A retry prompt repeats all of that and then
    appends the previously generated query together with the error it caused,
    so the model still sees the schema and the original request while
    correcting its answer.
    """
    prompt = (
        f"{_PROMPT_HEADER}"
        f"Convert the request into an SQL query (SQLite syntax):\n\"\"\"\n{request}\n\"\"\"\n"
    )
    if previous_error is not None:
        prompt += (
            f"The previous SQL query was:\n{previous_sql}\n"
            f"It failed with error: {previous_error}\n"
            "Please provide a corrected SQL query.\n"
        )
    return prompt + "SQL Query:"


def _extract_sql(generated_text):
    """Return the first SQL statement found in the model's raw output.

    If the text contains a Markdown code fence, only the content of the first
    fenced block is considered (an opening ``sql``/``sqlite`` language tag is
    dropped). Everything after the first ``;`` is discarded because
    ``cursor.execute`` accepts a single statement. May return ``""``.
    """
    text = generated_text.strip()
    if "```" in text:
        _, _, after_fence = text.partition("```")
        fenced, _, _ = after_fence.partition("```")
        tag, newline, remainder = fenced.partition("\n")
        if newline and tag.strip().lower() in ("sql", "sqlite"):
            fenced = remainder
        text = fenced.strip()
    statement, semicolon, _ = text.partition(";")
    return (statement + semicolon).strip()


if st.button("Run Query"):
    if not question.strip():
        st.error("Please enter a query.")
    else:
        # Agentic loop: generate SQL, run it, and feed any error back to the
        # model until a statement succeeds or the attempts are used up.
        sql_query = None
        error_msg = None
        result_rows = None
        for _attempt in range(_MAX_ATTEMPTS):
            prompt = _build_prompt(question, sql_query, error_msg)
            output = generator(prompt, max_new_tokens=100, return_full_text=False)
            sql_query = _extract_sql(output[0]["generated_text"])

            if not sql_query:
                # Executing an empty string would "succeed" without doing
                # anything, so treat it as a failed attempt instead.
                error_msg = "the model returned an empty query"
                continue

            # NOTE(security): the statement below is produced by the model from
            # free-form user input and is executed with write access, so it can
            # modify or drop the sample tables. This is kept as-is; consider a
            # read-only connection if that is not intended.
            try:
                cursor.execute(sql_query)
                # cursor.description is None for statements that return no
                # rows; this also recognises WITH ... SELECT and PRAGMA, which
                # a startswith("select") test would miss.
                if cursor.description is not None:
                    rows = cursor.fetchall()
                else:
                    rows = []
                # Committing is a no-op when no transaction is open.
                conn.commit()
            except (sqlite3.Error, sqlite3.Warning) as exc:
                # sqlite3.Warning is not an sqlite3.Error subclass; older
                # Python versions raise it for multi-statement input.
                # Roll back so a failed write does not leave a transaction
                # open on the cached connection.
                conn.rollback()
                error_msg = str(exc)
            else:
                result_rows = rows
                break
        else:
            # Reached only when no attempt ended in `break`.
            st.error(f"SQL execution failed after {_MAX_ATTEMPTS} attempts: {error_msg}")

        # Display final SQL and results
        if sql_query:
            st.subheader("Generated SQL Query")
            st.code(sql_query)
        if result_rows is not None:
            st.subheader("Query Results")
            if result_rows:
                st.table(result_rows)
            else:
                st.write("(No results returned)")