File size: 3,997 Bytes
6f5d272
 
971daac
66bae08
6f5d272
 
 
66bae08
6f5d272
66bae08
 
6f5d272
 
 
66bae08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5d272
66bae08
 
 
 
 
 
 
 
 
 
 
 
6f5d272
66bae08
 
 
 
 
6f5d272
 
66bae08
 
6f5d272
66bae08
6f5d272
66bae08
 
 
 
 
6f5d272
66bae08
 
 
6f5d272
66bae08
 
6f5d272
66bae08
 
 
 
 
6f5d272
66bae08
 
 
6f5d272
 
 
66bae08
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5d272
66bae08
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import streamlit as st
from google import genai
from io import StringIO
import prompts # Assuming prompts.py is in the same directory

st.set_page_config(page_title="SQL AI Assistant", layout="wide")

# --- 1. Sidebar: Configuration & Context ---
# Collects the database context (dialect, database name, schema DDL) that is
# later substituted into the system instruction sent to the model.
st.sidebar.title("🛠️ Database Context")

# Dialect and DB Name
dialect = st.sidebar.selectbox("SQL Dialect", ["PostgreSQL", "MySQL", "SQLite", "BigQuery", "Snowflake"])
db_name = st.sidebar.text_input("Database Name", placeholder="e.g. production_db")

# File Uploader for Schema
st.sidebar.subheader("📄 Upload Schema")
uploaded_file = st.sidebar.file_uploader("Upload .sql or .txt file", type=["sql", "txt"])

# The uploaded file (if any) pre-fills the schema text area; the user can
# still edit the text manually afterwards.
initial_schema = ""
if uploaded_file is not None:
    raw_schema = uploaded_file.getvalue()
    try:
        # "utf-8-sig" reads plain UTF-8 and additionally drops a leading BOM,
        # which would otherwise end up as a stray character in the schema.
        initial_schema = raw_schema.decode("utf-8-sig")
    except UnicodeDecodeError:
        # Do not crash the whole app on a file saved in another encoding:
        # keep what can be decoded and tell the user to double-check it.
        initial_schema = raw_schema.decode("utf-8", errors="replace")
        st.sidebar.warning(
            "The uploaded file is not valid UTF-8; undecodable bytes were replaced. "
            "Please review the schema below."
        )

schema = st.sidebar.text_area(
    "Table Schemas (DDL)",
    value=initial_schema,
    placeholder="CREATE TABLE users (id INT, name TEXT...)",
    height=300,
)

# Option to toggle explanations
show_explanation = st.sidebar.checkbox("Show Logic Explanation", value=False)

if st.sidebar.button("🗑️ Clear Chat History"):
    # Reset the stored query together with the transcript; otherwise the
    # download button keeps offering SQL from a conversation that was cleared.
    st.session_state.messages = []
    st.session_state.last_query = ""
    st.rerun()

# --- 2. Initialize Gemini Client ---
# The API key is read from Streamlit secrets under "GOOGLE_API_KEY". The two
# failure modes (missing key vs. client construction failure) are reported
# separately so the message shown to the user matches the actual problem.
try:
    api_key = st.secrets["GOOGLE_API_KEY"]
except Exception:
    # Top-level boundary: any failure to read the secret is treated as
    # "key not configured" and reported below.
    api_key = None

if not api_key:
    st.error("API Key not found. Please ensure GOOGLE_API_KEY is set in Hugging Face Secrets.")
    st.stop()

try:
    client = genai.Client(api_key=api_key)
except Exception as e:
    st.error(f"Failed to initialize the Gemini client: {e}")
    st.stop()

# --- 3. Main UI ---
# Page heading and a one-line description of what the app does.
st.title("🤖 Gemini SQL Generator")
st.caption("Generate, inspect, and download optimized SQL queries.")

# Seed the per-session state on first run: the chat transcript and the most
# recently generated query. Existing values are left untouched on reruns.
for state_key, initial_value in (("messages", []), ("last_query", "")):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Replay the stored transcript: assistant turns are rendered as SQL code
# blocks, every other role as plain text.
for entry in st.session_state.messages:
    role = entry["role"]
    body = entry["content"]
    with st.chat_message(role):
        if role != "assistant":
            st.write(body)
        else:
            st.code(body, language="sql")

# --- 4. Chat Input & Generation ---
# On each submitted prompt: record and echo the user message, build the system
# instruction from the sidebar context, call the model, and store/render the
# result. Failures and empty responses are reported without touching the
# previously generated query.
if prompt := st.chat_input("Show me all orders from the last 30 days..."):
    # Add user message to state
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    # Format instructions from prompts.py
    # Note: SYSTEM_INSTRUCTION in prompts.py must contain the {dialect},
    # {db_name}, {schema}, and {explain} placeholders.
    explain_text = "Provide a brief explanation after the code." if show_explanation else "Provide ONLY the code."

    full_system_msg = prompts.SYSTEM_INSTRUCTION.format(
        dialect=dialect,
        db_name=db_name if db_name else "unspecified",
        schema=schema if schema else "No specific schema provided.",
        explain=explain_text,
    )

    with st.spinner("Generating SQL..."):
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                config={'system_instruction': full_system_msg},
                contents=prompts.USER_PROMPT_TEMPLATE.format(user_input=prompt),
            )
        except Exception as e:
            st.error(f"Generation failed: {e}")
        else:
            sql_output = response.text

            if not sql_output or not sql_output.strip():
                # The response may carry no text at all; do not store an
                # empty/None entry in the history or overwrite the last
                # downloadable query with it.
                st.warning("The model returned an empty response. Try rephrasing your request.")
            else:
                st.session_state.last_query = sql_output

                # Add assistant response to state
                st.session_state.messages.append({"role": "assistant", "content": sql_output})
                with st.chat_message("assistant"):
                    st.code(sql_output, language="sql")

# --- 5. Download Action ---
# Offered only once a query has been generated in this session; the file
# contains the most recent model output exactly as stored.
if st.session_state.last_query:
    st.divider()
    st.download_button(
        label="💾 Download Latest Query",
        data=st.session_state.last_query,
        file_name="generated_query.sql",
        mime="text/x-sql",
        use_container_width=True,
    )