File size: 4,343 Bytes
6492793
6f5d272
 
971daac
66bae08
88ddb81
 
 
 
2d32578
 
 
 
 
 
 
 
6f5d272
 
 
66bae08
6f5d272
66bae08
 
6f5d272
 
 
66bae08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5d272
66bae08
 
 
 
 
 
 
 
 
 
 
 
6f5d272
66bae08
 
 
 
 
6f5d272
 
66bae08
 
6f5d272
66bae08
6f5d272
66bae08
 
 
 
 
6f5d272
66bae08
 
 
6f5d272
66bae08
 
6f5d272
66bae08
 
 
 
 
6f5d272
66bae08
 
 
6f5d272
 
 
66bae08
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5d272
66bae08
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import streamlit as st
from google import genai
from io import StringIO
import prompts # Assuming prompts.py is in the same directory

# Older Streamlit releases require st.set_page_config to be the first Streamlit
# command executed in the script, so it runs before any sidebar output or any
# st.error call below.
st.set_page_config(page_title="SQL AI Assistant", layout="wide")

# --- 1. Sidebar: Configuration & Context ---
st.sidebar.title("🛠️ Database Context")

# Dialect and DB Name
dialect = st.sidebar.selectbox("SQL Dialect", ["PostgreSQL", "MySQL", "SQLite", "BigQuery", "Snowflake"])
db_name = st.sidebar.text_input("Database Name", placeholder="e.g. production_db")

# File Uploader for Schema
st.sidebar.subheader("📄 Upload Schema")
uploaded_file = st.sidebar.file_uploader("Upload .sql or .txt file", type=["sql", "txt"])

# An uploaded file pre-fills the schema text area; the user can still edit it.
initial_schema = ""
if uploaded_file is not None:
    try:
        # "utf-8-sig" decodes plain UTF-8 unchanged and additionally strips a
        # leading byte-order mark (common in files saved by Windows editors).
        initial_schema = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        # Report a bad upload instead of crashing the whole page.
        st.sidebar.error("Could not read the uploaded file: it is not valid UTF-8 text.")

schema = st.sidebar.text_area(
    "Table Schemas (DDL)",
    value=initial_schema,
    placeholder="CREATE TABLE users (id INT, name TEXT...)",
    height=300,
)

# Option to toggle explanations
show_explanation = st.sidebar.checkbox("Show Logic Explanation", value=False)

if st.sidebar.button("🗑️ Clear Chat History"):
    st.session_state.messages = []
    # Drop the cached query as well so the download button does not outlive
    # the history it belonged to.
    st.session_state.last_query = ""
    st.rerun()

# --- 2. Initialize Gemini Client ---
# Resolve the key exactly once: environment variable first (hosted Spaces
# expose secrets as env vars), then .streamlit/secrets.toml via st.secrets for
# local runs. st.secrets does not read environment variables, so it must not
# be required unconditionally.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    try:
        api_key = st.secrets["GOOGLE_API_KEY"]
    except Exception:
        # st.secrets raises when no secrets file exists, it cannot be parsed,
        # or it lacks the key; all mean "no key available".
        api_key = None

if not api_key:
    st.error("Missing GOOGLE_API_KEY. Please add it to your Space Secrets.")
    st.stop()

client = genai.Client(api_key=api_key)

# --- 3. Main UI ---
st.title("🤖 Gemini SQL Generator")
st.caption("Generate, inspect, and download optimized SQL queries.")

# Initialize per-session state on the first run: the chat transcript, and the
# most recent generated query (read by the download button further down).
if "messages" not in st.session_state:
    st.session_state.messages = []
if "last_query" not in st.session_state:
    st.session_state.last_query = ""

# Replay the transcript on every rerun. Assistant replies are rendered as SQL
# code blocks; user prompts are rendered as plain text.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        if msg["role"] == "assistant":
            st.code(msg["content"], language="sql")
        else:
            st.write(msg["content"])

# --- 4. Chat Input & Generation ---
if prompt := st.chat_input("Show me all orders from the last 30 days..."):
    # Record and echo the user's message before calling the model.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    # Fill the system-prompt template from prompts.py.
    # NOTE: prompts.SYSTEM_INSTRUCTION is expected to use the {dialect},
    # {db_name}, {schema} and {explain} placeholders; str.format raises if the
    # template contains any other {field} or an unescaped literal brace.
    explain_text = "Provide a brief explanation after the code." if show_explanation else "Provide ONLY the code."

    full_system_msg = prompts.SYSTEM_INSTRUCTION.format(
        dialect=dialect,
        db_name=db_name if db_name else "unspecified",
        schema=schema if schema else "No specific schema provided.",
        explain=explain_text,
    )

    with st.spinner("Generating SQL..."):
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                config={'system_instruction': full_system_msg},
                contents=prompts.USER_PROMPT_TEMPLATE.format(user_input=prompt),
            )
            sql_output = response.text
        except Exception as e:
            # API failures (network, auth, quota, ...) are reported rather than
            # raised so the rest of the page still renders.
            st.error(f"Generation failed: {e}")
        else:
            # response.text can be None/empty when the model returns no text
            # part; storing that would put an empty entry in the history.
            if not (sql_output and sql_output.strip()):
                st.warning("The model returned an empty response. Try rephrasing your request.")
            else:
                st.session_state.last_query = sql_output

                # Add assistant response to state and render it.
                st.session_state.messages.append({"role": "assistant", "content": sql_output})
                with st.chat_message("assistant"):
                    st.code(sql_output, language="sql")

# --- 5. Download Action ---
# Offer the most recently generated query as a .sql file; the button stays
# hidden until a non-empty query has been stored in session state.
if st.session_state.last_query:
    st.divider()
    st.download_button(
        label="💾 Download Latest Query",
        data=st.session_state.last_query,
        file_name="generated_query.sql",
        mime="text/x-sql",
        use_container_width=True,
    )