# File size: 3,047 Bytes
# 59dfc66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
import google.generativeai as genai
import sqlite3
import pandas as pd
import os

# 1. SETUP: Read the Gemini API key from the environment (expected to be
#    supplied through the host's Settings -> Secrets). Without it, show an
#    error in the UI; with it, configure the Gemini client.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    st.error("Missing Gemini API Key. Go to Settings -> Secrets.")
else:
    genai.configure(api_key=api_key)

# 2. DATABASE: On first run there is no sales.db yet, so build it with the
#    project's generate_db module before any helper tries to open it.
if not os.path.exists("sales.db"):
    import generate_db
    generate_db.init_db()

# 3. HELPER: Get Schema (Context for the LLM)
def get_schema(db_path='sales.db'):
    """Describe every user table in the SQLite database as text for the LLM prompt.

    Args:
        db_path: Path of the SQLite database file. Defaults to ``'sales.db'``,
            the file this app bootstraps on startup.

    Returns:
        One line per table, in ``sqlite_master`` order, formatted as
        ``"Table: <name>, Columns: <col1>, <col2>, ...\\n"``. Returns ``""`` when
        the database has no user tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # Names starting with "sqlite_" are reserved by SQLite for internal
        # tables (e.g. sqlite_sequence); they are noise for the model.
        table_names = [
            row[0] for row in cursor.fetchall() if not row[0].startswith('sqlite_')
        ]

        lines = []
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier (doubling embedded quotes) instead of pasting it raw;
            # otherwise names with spaces or quotes produce a syntax error.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            columns = [col[1] for col in cursor.fetchall()]
            lines.append(f"Table: {table_name}, Columns: {', '.join(columns)}\n")
        return "".join(lines)
    finally:
        # Close even if a query raises, so the file handle is not leaked.
        conn.close()

# 4. HELPER: Run SQL safely
def run_query(sql, db_path='sales.db'):
    """Execute one SQL statement against the database over a read-only connection.

    The SQL comes from an LLM driven by free-form user input, so it is
    untrusted. The connection is opened with SQLite's ``mode=ro`` URI flag, so a
    generated ``DROP``/``DELETE``/``UPDATE`` fails with an error instead of
    modifying the database.

    Args:
        sql: A single SQL statement (normally a ``SELECT``).
        db_path: Path of the SQLite database file. Defaults to ``'sales.db'``.

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` if
        connecting or executing fails. Errors are returned rather than raised
        so the caller can feed them back to the model for self-correction.
    """
    from pathlib import Path

    conn = None
    try:
        # as_uri() yields an absolute, percent-encoded file: URI, so characters
        # such as '?' or '#' in the path cannot be misread as URI delimiters.
        uri = Path(db_path).resolve().as_uri() + '?mode=ro'
        conn = sqlite3.connect(uri, uri=True)
        df = pd.read_sql_query(sql, conn)
        return df, None
    except Exception as e:  # any DB/pandas failure becomes feedback text
        return None, str(e)
    finally:
        if conn is not None:
            conn.close()

# 5. CORE AGENT LOGIC (Planner -> Executor -> Reflector)
def _strip_code_fences(text):
    """Remove Markdown code fences (```, ```sql, ```SQL, ```sqlite) from model output."""
    import re

    # "sqlite" is listed before "sql" so the longer tag is removed whole.
    return re.sub(r"```(?:sqlite|sql)?", "", text, flags=re.IGNORECASE).strip()


def _generate_sql(model, prompt):
    """Ask the model for SQL; return ``(sql, None)`` or ``(None, error_message)``."""
    try:
        response = model.generate_content(prompt)
        # NOTE(review): response.text raises when the response has no text
        # part (e.g. a blocked candidate); treat that like an API failure.
        text = response.text
    except Exception as e:
        return None, f"Model request failed: {e}"
    sql = _strip_code_fences(text)
    if not sql:
        return None, "Model returned an empty SQL query."
    return sql, None


def agentic_text_to_sql(user_query, max_retries=1):
    """Translate a question into SQL, run it, and self-correct on failure.

    Flow: ask Gemini for SQL given the database schema (planner), execute it
    via ``run_query`` (executor), and if execution fails, send the failing SQL,
    its error, and the original question back to the model for a fix
    (reflector), up to ``max_retries`` times.

    Args:
        user_query: The user's question in plain language.
        max_retries: Number of self-correction rounds to try after the first
            query fails. Defaults to 1 (a single retry).

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` if a
        model call fails or no attempted query succeeds.
    """
    model = genai.GenerativeModel('gemini-1.5-flash')
    schema = get_schema()

    # ATTEMPT 1: Generate SQL
    prompt_v1 = f"""
    You are an expert SQL Data Analyst.
    Schema: {schema}
    Question: "{user_query}"
    Output ONLY valid SQL. No markdown.
    """
    sql_candidate, gen_error = _generate_sql(model, prompt_v1)
    if gen_error is not None:
        return None, gen_error

    st.toast(f"Trying: {sql_candidate}")  # Show user what's happening

    # EXECUTE
    df, error = run_query(sql_candidate)

    # REFLECTION LOOP: feed the error back to the model and retry.
    for attempt in range(max_retries):
        if error is None:
            break
        label = "Initial query" if attempt == 0 else "Query"
        st.warning(f"{label} failed: {error}. Attempting self-correction...")

        # Include the original question so the fix preserves the user's intent,
        # not just syntactic validity.
        reflection_prompt = f"""
        The query "{sql_candidate}" failed with error: "{error}".
        Question: "{user_query}"
        Schema: {schema}
        Fix the SQL. Output ONLY the fixed SQL.
        """
        fixed_sql, gen_error = _generate_sql(model, reflection_prompt)
        if gen_error is not None:
            return None, gen_error

        sql_candidate = fixed_sql
        df, error = run_query(sql_candidate)
        if error is None:
            # Only announce success once the fixed query actually ran.
            st.success(f"Fixed Query: {fixed_sql}")

    # `is not None` (not truthiness) so an empty error string still counts as failure.
    if error is not None:
        return None, f"Could not fix query. Error: {error}"
    return df, None

# 6. UI: Streamlit
st.title("AI SQL Agent 🕵️‍♂️")
st.write("Ask questions about: `customers`, `products`, `orders`")

query = st.text_input("Ask a question:", "Who bought the most expensive product?")

if st.button("Run Agent"):
    with st.spinner("Agent is thinking..."):
        results, err = agentic_text_to_sql(query)
        if err:
            st.error(err)
        else:
            st.dataframe(results)