# Origin: "Create app.py" by Gaurav-2273 (commit 59dfc66, verified).
import os
import re
import sqlite3
from contextlib import closing
from pathlib import Path

import google.generativeai as genai
import pandas as pd
import streamlit as st
# 1. SETUP: Get API Key from Secrets
# NOTE(review): the error text below points users at Settings -> Secrets;
# presumably the hosting platform exposes those secrets as environment
# variables -- confirm for your deployment.
api_key = os.getenv("GEMINI_API_KEY")
if api_key:
    genai.configure(api_key=api_key)
else:
    st.error("Missing Gemini API Key. Go to Settings -> Secrets.")
    # Without a key every genai call further down would fail with an opaque
    # traceback, so end this script run right after showing the message.
    st.stop()

# 2. DATABASE: Ensure it exists
# First run only: build sales.db with the project's generate_db module. The
# import is deferred so that module is only loaded when the file is missing.
if not os.path.exists("sales.db"):
    import generate_db
    generate_db.init_db()
# 3. HELPER: Get Schema (Context for the LLM)
def get_schema(db_path='sales.db'):
    """Describe every user table in the SQLite database as plain text.

    The text is embedded in the LLM prompt so the model knows which tables
    and columns it may reference.

    Args:
        db_path: Path of the SQLite database file. Defaults to 'sales.db',
            the file this app bootstraps on startup.

    Returns:
        One newline-terminated line per table, of the form
        "Table: <name>, Columns: <col1>, <col2>, ...". SQLite's internal
        tables (names starting with "sqlite_") are omitted. An empty
        database yields "".
    """
    # sqlite3's own context manager only commits/rolls back; closing()
    # guarantees the connection is closed even if a query raises.
    with closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # "sqlite_" names are reserved for SQLite internals (e.g.
        # sqlite_sequence) and would only add noise to the prompt.
        table_names = [
            row[0] for row in cursor.fetchall()
            if not row[0].startswith('sqlite_')
        ]
        lines = []
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier to cope with spaces or quotes in table names.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted})")
            columns = [col[1] for col in cursor.fetchall()]
            lines.append(f"Table: {table_name}, Columns: {', '.join(columns)}\n")
    return "".join(lines)
# 4. HELPER: Run SQL safely
def run_query(sql, db_path='sales.db'):
    """Execute one SQL statement against the database without modifying it.

    The SQL comes from an LLM, so the database is opened through a
    read-only SQLite URI (mode=ro). A generated DROP/UPDATE/INSERT then fails
    with an error instead of changing the data.

    Args:
        sql: The SQL text to run.
        db_path: Path of the SQLite database file. Defaults to 'sales.db'.

    Returns:
        (DataFrame, None) on success, or (None, error_message) on any
        failure. Every exception is caught on purpose: the message is fed
        back to the model for self-correction rather than crashing the app.
    """
    # as_uri() needs an absolute path and percent-encodes characters that
    # would otherwise be misread as URI syntax.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    try:
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            return pd.read_sql_query(sql, conn), None
    except Exception as e:
        return None, str(e)
# 5. CORE AGENT LOGIC (Planner -> Executor -> Reflector)

# Matches a Markdown code fence with an optional sql/sqlite language tag (any
# case) and captures its contents. The closing fence is optional so that a
# truncated reply still yields its SQL.
_SQL_FENCE_RE = re.compile(
    r"```[ \t]*(?:sqlite3?|sql)?\s*(.*?)\s*(?:```|\Z)",
    re.IGNORECASE | re.DOTALL,
)


def _strip_sql_fences(text):
    """Return the SQL in *text*, removing any Markdown code fence around it."""
    match = _SQL_FENCE_RE.search(text)
    if match and match.group(1).strip():
        return match.group(1).strip()
    # No usable fenced section: fall back to deleting stray fence markers.
    return text.replace('```', '').strip()


def _generate_sql(model, prompt):
    """Ask *model* for SQL; return (sql, None) or (None, error_message)."""
    try:
        response = model.generate_content(prompt)
        # NOTE(review): .text is assumed to raise when the response carries no
        # text (e.g. a blocked reply) -- confirm against the SDK version used.
        text = response.text
    except Exception as exc:  # network/API failures, empty or blocked replies
        return None, f"LLM request failed: {exc}"
    sql = _strip_sql_fences(text)
    if not sql:
        return None, "LLM returned an empty query."
    return sql, None


def agentic_text_to_sql(user_query, max_retries=1, model_name='gemini-1.5-flash'):
    """Answer *user_query* by generating, running and self-correcting SQL.

    Planner: ask the model for SQL given the database schema. Executor: run
    it with run_query. Reflector: if execution fails, send the error back to
    the model for a corrected query, up to *max_retries* times.

    Args:
        user_query: Natural-language question from the user.
        max_retries: Number of self-correction rounds after the first query
            fails. Defaults to 1 (a single correction attempt).
        model_name: Name of the Gemini model to use.

    Returns:
        (DataFrame, None) on success, otherwise (None, error_message).
    """
    # Don't spend an LLM call on an empty question.
    if not user_query or not user_query.strip():
        return None, "Please enter a question."

    model = genai.GenerativeModel(model_name)
    schema = get_schema()

    # ATTEMPT 1: Generate SQL
    prompt_v1 = f"""
    You are an expert SQL Data Analyst.
    Schema: {schema}
    Question: "{user_query}"
    Output ONLY valid SQL. No markdown.
    """
    sql, gen_error = _generate_sql(model, prompt_v1)
    if gen_error is not None:
        return None, gen_error
    st.toast(f"Trying: {sql}")  # Show user what's happening

    # EXECUTE
    df, error = run_query(sql)

    # REFLECTION LOOP (If error, fix it)
    for attempt in range(1, max_retries + 1):
        if error is None:
            break
        label = "Initial query" if attempt == 1 else "Query"
        st.warning(f"{label} failed: {error}. Attempting self-correction...")
        # Include the question so the model fixes the query toward the
        # user's intent, not just toward something that parses.
        reflection_prompt = f"""
        The user asked: "{user_query}"
        The query "{sql}" failed with error: "{error}".
        Schema: {schema}
        Fix the SQL. Output ONLY the fixed SQL.
        """
        sql, gen_error = _generate_sql(model, reflection_prompt)
        if gen_error is not None:
            return None, gen_error
        st.toast(f"Retrying with: {sql}")
        df, error = run_query(sql)
        # Only announce a fix once the corrected query has actually succeeded.
        if error is None:
            st.success(f"Fixed Query: {sql}")

    if error is not None:
        return None, f"Could not fix query. Error: {error}"
    return df, None
# 6. UI: Streamlit front end (Streamlit re-executes this on every interaction).
st.title("AI SQL Agent 🕵️‍♂️")
st.write("Ask questions about: `customers`, `products`, `orders`")

query = st.text_input("Ask a question:", "Who bought the most expensive product?")

# The button reports True only on the rerun triggered by the click.
run_requested = st.button("Run Agent")
if run_requested:
    with st.spinner("Agent is thinking..."):
        results, err = agentic_text_to_sql(query)
        # Success path first: an empty/None err means results holds the data.
        if not err:
            st.dataframe(results)
        else:
            st.error(err)