Gaurav-2273 commited on
Commit
f76a063
·
verified ·
1 Parent(s): 24aab45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -65
app.py CHANGED
@@ -1,76 +1,113 @@
1
- import streamlit as st
2
- import google.generativeai as genai
3
  import sqlite3
4
- import pandas as pd
5
- import os
6
-
7
- # --- 1. SETTINGS ---
8
- st.set_page_config(page_title="SQL AI Agent", layout="wide")
9
-
10
- # Get Key from Secrets
11
- api_key = os.getenv("GEMINI_API_KEY")
12
 
13
- if api_key:
14
- genai.configure(api_key=api_key)
15
- else:
16
- st.error("⚠️ API Key missing! Go to Settings > Secrets and add 'GEMINI_API_KEY'")
17
 
18
- # --- 2. DATABASE SETUP ---
19
- DB_NAME = "sales.db"
20
  def init_db():
21
- conn = sqlite3.connect(DB_NAME)
22
  c = conn.cursor()
23
- c.execute('CREATE TABLE IF NOT EXISTS customers (id INTEGER PRIMARY KEY, name TEXT, region TEXT)')
24
- c.execute('CREATE TABLE IF NOT EXISTS products (id INTEGER PRIMARY KEY, name TEXT, category TEXT, price REAL)')
25
- c.execute("INSERT OR IGNORE INTO customers VALUES (1, 'Acme Corp', 'North'), (2, 'Globex', 'West')")
26
- c.execute("INSERT OR IGNORE INTO products VALUES (1, 'AI Widget', 'Software', 1000.0), (2, 'Cloud Server', 'Hardware', 5000.0)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  conn.commit()
28
- conn.close()
 
 
 
29
 
30
- if not os.path.exists(DB_NAME):
31
- init_db()
 
 
 
 
 
 
 
32
 
33
- # --- 3. AGENTIC LOGIC ---
34
- def sql_agent(user_prompt):
35
- # Forced stable model names
36
- model_names = ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-pro']
37
- model = None
38
-
39
- for name in model_names:
40
- try:
41
- model = genai.GenerativeModel(name)
42
- # Test if this model works
43
- model.generate_content("ping")
44
- break
45
- except:
46
- continue
47
-
48
- if not model:
49
- return None, "Could not connect to any Gemini models.", ""
50
 
51
- schema = "Tables: customers (id, name, region), products (id, name, category, price)"
52
- prompt = f"System: Output ONLY clean SQLite SQL. No backticks. Schema: {schema}\nUser: {user_prompt}"
53
-
54
- try:
55
- response = model.generate_content(prompt)
56
- sql = response.text.strip().replace('```sql', '').replace('```', '')
57
-
58
- conn = sqlite3.connect(DB_NAME)
59
- df = pd.read_sql_query(sql, conn)
60
- conn.close()
61
- return df, None, sql
62
- except Exception as e:
63
- return None, str(e), "Error"
64
 
65
- # --- 4. UI ---
66
- st.title("🕵️‍♂️ Agentic Text-to-SQL")
67
- query = st.text_input("Ask a question:", "Show me all products")
68
 
69
- if st.button("Run Analysis"):
70
- with st.spinner("Agent is querying..."):
71
- data, err, sql = sql_agent(query)
72
- if err:
73
- st.error(f"Error: {err}")
74
- else:
75
- st.code(sql, language="sql")
76
- st.dataframe(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import sqlite3
2
+ import streamlit as st
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 
 
 
 
 
4
 
5
+ st.set_page_config(page_title="NL to SQL Agent")
 
 
 
6
 
7
+ # Initialize or retrieve cached DB connection
8
+ @st.cache_resource
9
  def init_db():
10
+ conn = sqlite3.connect("customer_product.db", check_same_thread=False)
11
  c = conn.cursor()
12
+ # Create tables if they don't exist
13
+ c.execute("""CREATE TABLE IF NOT EXISTS customers(
14
+ id INTEGER PRIMARY KEY,
15
+ name TEXT, email TEXT, city TEXT)""")
16
+ c.execute("""CREATE TABLE IF NOT EXISTS products(
17
+ id INTEGER PRIMARY KEY,
18
+ name TEXT, price REAL)""")
19
+ # Insert dummy data if tables are empty
20
+ c.execute("SELECT COUNT(*) FROM customers")
21
+ if c.fetchone()[0] == 0:
22
+ customers = [
23
+ (1, "Alice", "alice@example.com", "New York"),
24
+ (2, "Bob", "bob@example.com", "Los Angeles"),
25
+ (3, "Carol", "carol@example.com", "Chicago")
26
+ ]
27
+ c.executemany("INSERT INTO customers VALUES (?,?,?,?)", customers)
28
+ c.execute("SELECT COUNT(*) FROM products")
29
+ if c.fetchone()[0] == 0:
30
+ products = [
31
+ (1, "Widget", 9.99),
32
+ (2, "Gizmo", 14.99),
33
+ (3, "Doodad", 7.49)
34
+ ]
35
+ c.executemany("INSERT INTO products VALUES (?,?,?)", products)
36
  conn.commit()
37
+ return conn
38
+
39
+ conn = init_db()
40
+ cursor = conn.cursor()
41
 
42
+ # Load the Hugging Face model and tokenizer once (cached)
43
+ @st.cache_resource
44
+ def load_generator():
45
+ model_id = "microsoft/Phi-4-mini-flash-reasoning" # example; use any available LLM
46
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
47
+ model = AutoModelForCausalLM.from_pretrained(model_id)
48
+ device = 0 if model.device.type == 'cuda' else -1
49
+ gen = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
50
+ return gen
51
 
52
+ generator = load_generator()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ st.title("Natural Language to SQL Query App")
55
+ st.write("Enter a request in plain English; the app will generate and run SQL on a sample database.")
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # User input
58
+ question = st.text_area("Enter your query:", height=100)
 
59
 
60
+ if st.button("Run Query"):
61
+ if not question.strip():
62
+ st.error("Please enter a query.")
63
+ else:
64
+ # Agentic loop: try up to 3 attempts
65
+ sql_query = None
66
+ error_msg = None
67
+ result_rows = None
68
+ for attempt in range(3):
69
+ # Construct prompt for the LLM
70
+ if attempt == 0:
71
+ prompt = (
72
+ "You are an assistant that converts English questions into SQL queries. "
73
+ "The database schema is:\n"
74
+ "Customers(id, name, email, city)\n"
75
+ "Products(id, name, price)\n"
76
+ f"Convert the request into an SQL query (SQLite syntax):\n\"\"\"\n{question}\n\"\"\"\n"
77
+ "SQL Query:"
78
+ )
79
+ else:
80
+ prompt = (
81
+ f"The previous SQL query was:\n{sql_query}\n"
82
+ f"It failed with error: {error_msg}\n"
83
+ "Please provide a corrected SQL query.\n"
84
+ "SQL Query:"
85
+ )
86
+ # Generate SQL with the LLM
87
+ output = generator(prompt, max_new_tokens=100, return_full_text=False)
88
+ sql_query = output[0]["generated_text"].strip()
89
+ # Try executing the SQL
90
+ try:
91
+ cursor.execute(sql_query)
92
+ # Fetch results if it's a SELECT
93
+ if sql_query.strip().lower().startswith("select"):
94
+ result_rows = cursor.fetchall()
95
+ else:
96
+ conn.commit()
97
+ result_rows = []
98
+ # Success: break loop
99
+ break
100
+ except sqlite3.Error as e:
101
+ error_msg = str(e)
102
+ if attempt == 2:
103
+ st.error(f"SQL execution failed after 3 attempts: {error_msg}")
104
+ # Display final SQL and results
105
+ if sql_query:
106
+ st.subheader("Generated SQL Query")
107
+ st.code(sql_query)
108
+ if result_rows is not None:
109
+ st.subheader("Query Results")
110
+ if len(result_rows) > 0:
111
+ st.table(result_rows)
112
+ else:
113
+ st.write("(No results returned)")