"""Streamlit app: a self-correcting text-to-SQL agent over a local SQLite DB.

The agent asks Gemini to translate a natural-language question into SQL,
runs it against ``sales.db`` (read-only), and, if execution fails, feeds the
error back to the model for a single correction attempt.
"""

import os
import re
import sqlite3
from contextlib import closing

import google.generativeai as genai
import pandas as pd
import streamlit as st

DB_PATH = "sales.db"
MODEL_NAME = "gemini-1.5-flash"

# Markdown code-fence markers the model may emit despite instructions,
# e.g. ``` / ```sql / ```SQL / ```sqlite.
_FENCE_RE = re.compile(r"```(?:sqlite|sql)?", re.IGNORECASE)

# 1. SETUP: Get API Key from Secrets
api_key = os.getenv("GEMINI_API_KEY")
if api_key:
    genai.configure(api_key=api_key)
else:
    st.error("Missing Gemini API Key. Go to Settings -> Secrets.")
    # Every model call would fail without a key; halt this script run instead
    # of letting the user click "Run Agent" into a traceback.
    st.stop()

# 2. DATABASE: Ensure it exists
if not os.path.exists(DB_PATH):
    # Project-local module (not shown here) that builds the sample database.
    import generate_db

    generate_db.init_db()


# 3. HELPER: Get Schema (Context for the LLM)
def get_schema():
    """Return a plain-text description of every user table and its columns.

    One line per table, formatted as ``Table: <name>, Columns: <a>, <b>, ...``.
    The string is embedded in the LLM prompts as schema context. SQLite's
    internal ``sqlite_%`` tables are skipped because they are not queryable
    business data.
    """
    with closing(sqlite3.connect(DB_PATH)) as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        )
        table_names = [row[0] for row in cursor.fetchall()]

        lines = []
        for table_name in table_names:
            # Quote the identifier so names with spaces or keywords still work.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted})")
            columns = [col[1] for col in cursor.fetchall()]
            lines.append(f"Table: {table_name}, Columns: {', '.join(columns)}\n")
    return "".join(lines)


# 4. HELPER: Run SQL safely
def run_query(sql):
    """Execute *sql* against the database over a read-only connection.

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` if the
        statement fails. This includes any attempt to modify the database,
        which the read-only connection rejects.
    """
    # The SQL is model-generated from user input, so never grant write access
    # (DDL such as DROP TABLE would otherwise autocommit).
    uri = f"file:{DB_PATH}?mode=ro"
    try:
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            return pd.read_sql_query(sql, conn), None
    except Exception as e:  # boundary: the message is fed to the reflection step
        return None, str(e)


def _clean_sql(text):
    """Strip Markdown code fences and surrounding whitespace from model output."""
    return _FENCE_RE.sub("", text).strip()


def _generate_sql(model, prompt):
    """Send *prompt* to *model* and return the cleaned SQL text.

    Raises whatever the model client raises (including when ``response.text``
    is unavailable, e.g. for a blocked response).
    """
    response = model.generate_content(prompt)
    return _clean_sql(response.text)


# 5. CORE AGENT LOGIC (Planner -> Executor -> Reflector)
def agentic_text_to_sql(user_query):
    """Translate *user_query* to SQL, run it, and self-correct once on failure.

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` if a
        model call fails or both the initial and the corrected query fail.
    """
    model = genai.GenerativeModel(MODEL_NAME)
    schema = get_schema()

    # ATTEMPT 1: Generate SQL
    prompt_v1 = f"""
    You are an expert SQL Data Analyst. The database is SQLite.
    Schema: {schema}
    Question: "{user_query}"
    Output ONLY valid SQLite SQL. No markdown.
    """
    try:
        sql_candidate = _generate_sql(model, prompt_v1)
    except Exception as e:  # API/network/safety-block failures
        return None, f"Model call failed: {e}"
    st.toast(f"Trying: {sql_candidate}")  # Show user what's happening

    # EXECUTE
    df, error = run_query(sql_candidate)
    if error is None:
        return df, None

    # REFLECTION: feed the error (and the original question) back once.
    st.warning(f"Initial query failed: {error}. Attempting self-correction...")
    reflection_prompt = f"""
    The query "{sql_candidate}" failed with error: "{error}".
    The database is SQLite.
    Schema: {schema}
    Original question: "{user_query}"
    Fix the SQL so it answers the question. Output ONLY the fixed SQL.
    """
    try:
        fixed_sql = _generate_sql(model, reflection_prompt)
    except Exception as e:
        return None, f"Model call failed: {e}"
    # Not yet a success: the fixed query has not been executed.
    st.info(f"Retrying with fixed query: {fixed_sql}")

    df, error_v2 = run_query(fixed_sql)
    if error_v2 is not None:
        return None, f"Could not fix query. Error: {error_v2}"
    return df, None


# 6. UI: Streamlit
st.title("AI SQL Agent 🕵️‍♂️")
st.write("Ask questions about: `customers`, `products`, `orders`")
query = st.text_input("Ask a question:", "Who bought the most expensive product?")

if st.button("Run Agent"):
    if not query.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Agent is thinking..."):
            results, err = agentic_text_to_sql(query)
        if err is not None:
            st.error(err)
        else:
            st.dataframe(results)