Spaces:
No application file
No application file
import os
import re
import sqlite3
from pathlib import Path

import google.generativeai as genai
import pandas as pd
import streamlit as st
# 1. SETUP: Read the Gemini API key from the environment. Hosting platforms
# such as Hugging Face Spaces expose configured secrets as environment variables.
api_key = os.getenv("GEMINI_API_KEY")
if api_key:
    genai.configure(api_key=api_key)
else:
    st.error("Missing Gemini API Key. Go to Settings -> Secrets.")
    # Halt this script run: without a configured key, every model call made
    # when the user clicks "Run Agent" would fail with an SDK error.
    st.stop()
# 2. DATABASE: Seed sales.db on first launch.
_db_exists = os.path.exists("sales.db")
if not _db_exists:
    # generate_db is a project-local module (not shown here). Importing it
    # lazily means it is only loaded when the database file is absent.
    import generate_db
    generate_db.init_db()
# 3. HELPER: Get Schema (Context for the LLM)
def get_schema(db_path="sales.db"):
    """Describe every user table in the SQLite database as plain text.

    The result has one line per table, ``Table: <name>, Columns: <c1>, <c2>``,
    each terminated by a newline. It is pasted into the LLM prompt as context.

    Args:
        db_path: Path of the SQLite database file. Defaults to ``sales.db``.

    Returns:
        The schema description, or ``""`` if the database has no user tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Skip SQLite's reserved internal tables (e.g. sqlite_sequence, which is
        # created for AUTOINCREMENT columns); they only add noise to the prompt.
        cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'"
        )
        table_names = [row[0] for row in cursor.fetchall()]

        lines = []
        for table_name in table_names:
            # The table-valued form of PRAGMA table_info accepts a bound
            # parameter, so names with spaces or quotes need no manual escaping.
            cursor.execute(
                "SELECT name FROM pragma_table_info(?) ORDER BY cid",
                (table_name,),
            )
            columns = [row[0] for row in cursor.fetchall()]
            lines.append(f"Table: {table_name}, Columns: {', '.join(columns)}\n")
        return "".join(lines)
    finally:
        # Close even if a query raises, so the connection is never leaked.
        conn.close()
# 4. HELPER: Run LLM-generated SQL on a read-only connection
def run_query(sql, db_path="sales.db"):
    """Execute *sql* against the SQLite database and return the result rows.

    The SQL is produced by an LLM from user-typed text, so it is untrusted.
    The database is therefore opened in read-only mode, and SQLite rejects
    any statement that would write (INSERT, UPDATE, DELETE, DROP, ...)
    instead of letting it modify the file.

    Args:
        sql: A single SQL statement.
        db_path: Path of the SQLite database file. Defaults to ``sales.db``.

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` on
        failure. Errors are returned rather than raised because the caller
        feeds the message back to the model for self-correction.
    """
    # mode=ro is only available through a URI filename. Path.as_uri()
    # percent-encodes the absolute path (spaces, '?', '#', Windows drives).
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    try:
        # With mode=ro a missing file fails here rather than being created.
        conn = sqlite3.connect(uri, uri=True)
    except sqlite3.Error as e:
        return None, str(e)
    try:
        return pd.read_sql_query(sql, conn), None
    except Exception as e:  # deliberate catch-all: the message goes back to the LLM
        return None, str(e)
    finally:
        conn.close()
# 5. CORE AGENT LOGIC (Planner -> Executor -> Reflector)
def _extract_sql(text):
    """Return the SQL in a model reply, stripping a markdown code fence if present."""
    # Accept any fence tag (```sql, ```SQL, ```sqlite, bare ```), and an
    # unterminated fence in case the reply was cut off.
    fenced = re.search(r"```[\w+-]*[ \t]*\r?\n(.*?)(?:```|\Z)", text, re.DOTALL)
    sql = fenced.group(1) if fenced else text.replace("```", "")
    return sql.strip()


def _generate_sql(model, prompt):
    """Ask *model* for SQL and return ``(sql, None)`` or ``(None, error_message)``."""
    try:
        # Reading .text can itself raise (google-generativeai raises when the
        # reply has no text part, e.g. a safety block), so it stays in the try.
        text = model.generate_content(prompt).text
    except Exception as e:  # API/SDK boundary: report to the UI instead of crashing
        return None, f"Model call failed: {e}"
    sql = _extract_sql(text)
    if not sql:
        return None, "Model returned no SQL."
    return sql, None


def agentic_text_to_sql(user_query, max_retries=1, model_name='gemini-1.5-flash'):
    """Answer a natural-language question by generating and running SQL.

    Asks the Gemini model for a SQLite query and runs it with ``run_query``.
    If the query fails, the failing SQL, its error and the original question
    are fed back to the model for up to ``max_retries`` correction rounds.

    Args:
        user_query: The user's question in plain language.
        max_retries: Number of self-correction rounds after the first attempt
            fails. Defaults to 1 (a single correction attempt).
        model_name: Name of the Gemini model to use.

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)``.
    """
    model = genai.GenerativeModel(model_name)
    schema = get_schema()

    # ATTEMPT 1: Generate SQL. Naming the dialect avoids e.g. T-SQL "TOP 1".
    prompt_v1 = f"""
    You are an expert SQL Data Analyst working with a SQLite database.
    Schema: {schema}
    Question: "{user_query}"
    Write a single read-only SELECT statement in SQLite syntax.
    Output ONLY valid SQL. No markdown.
    """
    sql, gen_error = _generate_sql(model, prompt_v1)
    if gen_error:
        return None, gen_error
    st.toast(f"Trying: {sql}")  # Show user what's happening

    # EXECUTE
    df, error = run_query(sql)

    # REFLECTION LOOP: include the question so the model can repair intent,
    # not just syntax.
    attempt = 0
    while error and attempt < max_retries:
        attempt += 1
        st.warning(
            f"Query failed: {error}. "
            f"Attempting self-correction ({attempt}/{max_retries})..."
        )
        reflection_prompt = f"""
        The following SQLite query was written to answer the question "{user_query}":
        {sql}
        It failed with error: "{error}".
        Schema: {schema}
        Fix the SQL. Output ONLY the fixed SQL. No markdown.
        """
        sql, gen_error = _generate_sql(model, reflection_prompt)
        if gen_error:
            return None, gen_error
        st.info(f"Retrying with: {sql}")
        df, error = run_query(sql)
        if not error:
            # Only claim success once the corrected query has actually run.
            st.success(f"Fixed Query: {sql}")

    if error:
        return None, f"Could not fix query. Error: {error}"
    return df, None
# 6. UI: Streamlit
# "\u200d" is the zero-width joiner that fuses detective + male sign into a
# single emoji; it is written as an escape so copy/paste cannot drop it.
st.title("AI SQL Agent 🕵️\u200d♂️")
st.write("Ask questions about: `customers`, `products`, `orders`")
query = st.text_input("Ask a question:", "Who bought the most expensive product?")
if st.button("Run Agent"):
    if not query.strip():
        # Don't spend a model call on an empty question.
        st.warning("Please enter a question first.")
    else:
        with st.spinner("Agent is thinking..."):
            results, err = agentic_text_to_sql(query)
        if err:
            st.error(err)
        elif results.empty:
            st.info("The query ran successfully but returned no rows.")
        else:
            st.dataframe(results)