Spaces:
Runtime error
Runtime error
| import sqlite3 | |
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| st.set_page_config(page_title="NL to SQL Agent") | |
# Open the sample SQLite database, creating and seeding it on first use.
def init_db(db_path="customer_product.db"):
    """Return a connection to the sample customer/product database.

    The ``customers`` and ``products`` tables are created if they do not
    exist, and each is seeded with three demo rows only when it is empty,
    so calling this repeatedly against the same file is idempotent.

    Args:
        db_path: SQLite database location. Defaults to the on-disk file
            ``customer_product.db``; pass ``":memory:"`` for a throwaway
            database.

    Returns:
        An open ``sqlite3.Connection`` created with
        ``check_same_thread=False`` so it may be used from a thread other
        than the one that created it.

    Raises:
        sqlite3.Error: If the schema or seed data cannot be written. The
            connection is closed before the error propagates.
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    try:
        # ``with conn`` commits on success and rolls back on any exception,
        # so a half-seeded database is never left behind.
        with conn:
            conn.execute(
                """CREATE TABLE IF NOT EXISTS customers(
                id INTEGER PRIMARY KEY,
                name TEXT, email TEXT, city TEXT)"""
            )
            conn.execute(
                """CREATE TABLE IF NOT EXISTS products(
                id INTEGER PRIMARY KEY,
                name TEXT, price REAL)"""
            )

            # Insert dummy data only if the tables are empty.
            if conn.execute("SELECT COUNT(*) FROM customers").fetchone()[0] == 0:
                customers = [
                    (1, "Alice", "alice@example.com", "New York"),
                    (2, "Bob", "bob@example.com", "Los Angeles"),
                    (3, "Carol", "carol@example.com", "Chicago"),
                ]
                conn.executemany(
                    "INSERT INTO customers VALUES (?,?,?,?)", customers
                )

            if conn.execute("SELECT COUNT(*) FROM products").fetchone()[0] == 0:
                products = [
                    (1, "Widget", 9.99),
                    (2, "Gizmo", 14.99),
                    (3, "Doodad", 7.49),
                ]
                conn.executemany("INSERT INTO products VALUES (?,?,?)", products)
    except Exception:
        # Do not leak the handle when setup fails; re-raise for the caller.
        conn.close()
        raise
    return conn
# Streamlit re-executes this script on every user interaction. Wrapping
# init_db with st.cache_resource opens the database once per server process
# and hands the same connection back on later reruns, instead of opening
# (and never closing) a new connection each time.
conn = st.cache_resource(init_db)()
# A cursor is cheap; create a fresh one per rerun on the shared connection.
cursor = conn.cursor()
# Load the Hugging Face model and tokenizer once; Streamlit caches the
# returned pipeline across script reruns.
@st.cache_resource
def load_generator(model_id="microsoft/Phi-4-mini-flash-reasoning"):
    """Build a text-generation pipeline for ``model_id``.

    Args:
        model_id: Hugging Face Hub identifier of a causal language model.
            The default is only an example; any available causal LM can
            be substituted.

    Returns:
        A ``transformers`` text-generation pipeline placed on the first
        CUDA device when one is available, otherwise on the CPU.
    """
    # Local import keeps this block self-contained; torch is the backend
    # AutoModelForCausalLM runs on.
    import torch

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    # NOTE(review): some Hub checkpoints need extra loading options or
    # GPU-only kernels; confirm this one loads on the target hardware.

    # Ask torch whether a GPU exists. The freshly loaded model sits on the
    # CPU, so inspecting ``model.device`` here would always select the CPU.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=device
    )
# Maximum number of generate -> execute rounds before giving up.
_MAX_ATTEMPTS = 3

# Schema description shown to the model in every prompt.
_SCHEMA_TEXT = (
    "Customers(id, name, email, city)\n"
    "Products(id, name, price)\n"
)


def _build_prompt(question, previous_sql=None, error_msg=None):
    """Return the LLM prompt for ``question``.

    When ``error_msg`` is given, the failed query and its error are
    appended, so a retry still sees the schema and the original request.
    """
    prompt = (
        "You are an assistant that converts English questions into SQL queries. "
        "The database schema is:\n"
        f"{_SCHEMA_TEXT}"
        f"Convert the request into an SQL query (SQLite syntax):\n\"\"\"\n{question}\n\"\"\"\n"
    )
    if error_msg is not None:
        prompt += (
            f"The previous SQL query was:\n{previous_sql}\n"
            f"It failed with error: {error_msg}\n"
            "Please provide a corrected SQL query.\n"
        )
    return prompt + "SQL Query:"


def _extract_sql(generated_text):
    """Return the first SQL statement found in the model's raw output.

    Markdown code fences are unwrapped and anything after the first
    complete statement is dropped, because ``cursor.execute`` accepts
    exactly one statement.
    """
    text = generated_text.strip()
    if "```" in text:
        _, _, after_fence = text.partition("```")
        fenced, _, _ = after_fence.partition("```")
        # Drop an optional language tag on the opening fence line.
        tag, newline, remainder = fenced.partition("\n")
        if newline and tag.strip().lower() in ("", "sql", "sqlite"):
            fenced = remainder
        text = fenced.strip()
    # complete_statement() ignores semicolons inside string literals, so the
    # cut lands on the real end of the first statement.
    for index, char in enumerate(text):
        if char == ";" and sqlite3.complete_statement(text[: index + 1]):
            return text[: index + 1].strip()
    return text


def _execute_sql(sql_query):
    """Run ``sql_query`` on the shared cursor and return its rows.

    Statements that produce a result set (SELECT, WITH ... SELECT, PRAGMA)
    return the fetched rows; anything else is committed and returns ``[]``.
    """
    cursor.execute(sql_query)
    # description is None exactly when the statement returns no rows.
    if cursor.description is not None:
        return cursor.fetchall()
    conn.commit()
    return []


generator = load_generator()

st.title("Natural Language to SQL Query App")
st.write("Enter a request in plain English; the app will generate and run SQL on a sample database.")

# User input
question = st.text_area("Enter your query:", height=100)

if st.button("Run Query"):
    if not question.strip():
        st.error("Please enter a query.")
    else:
        # Agentic loop: generate SQL, try it, and feed any error back.
        sql_query = None
        error_msg = None
        result_rows = None
        for _attempt in range(_MAX_ATTEMPTS):
            prompt = _build_prompt(question, sql_query, error_msg)
            output = generator(prompt, max_new_tokens=100, return_full_text=False)
            sql_query = _extract_sql(output[0]["generated_text"])
            if not sql_query:
                error_msg = "the model returned an empty query"
                continue
            # NOTE(security): the SQL is model-generated from user text and
            # may modify or drop tables; it runs with full write access.
            try:
                result_rows = _execute_sql(sql_query)
            except (sqlite3.Error, sqlite3.Warning) as exc:
                # sqlite3.Warning is not a sqlite3.Error subclass, and
                # Python < 3.12 raises it for multi-statement input.
                error_msg = str(exc)
                # Discard any half-applied write on the shared connection.
                conn.rollback()
            else:
                break
        else:
            # Reached only when no attempt succeeded.
            st.error(
                f"SQL execution failed after {_MAX_ATTEMPTS} attempts: {error_msg}"
            )

        # Display final SQL and results
        if sql_query:
            st.subheader("Generated SQL Query")
            st.code(sql_query)
        if result_rows is not None:
            st.subheader("Query Results")
            if result_rows:
                st.table(result_rows)
            else:
                st.write("(No results returned)")