Spaces:
Sleeping
Sleeping
File size: 4,343 Bytes
6492793 6f5d272 971daac 66bae08 88ddb81 2d32578 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 6f5d272 66bae08 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import os
import streamlit as st
from google import genai
from io import StringIO
import prompts # Assuming prompts.py is in the same directory
# Configure the page before any other Streamlit call: set_page_config must be
# the first Streamlit command in many Streamlit versions, and the key-lookup
# error path below calls st.error.
st.set_page_config(page_title="SQL AI Assistant", layout="wide")

# Resolve the Gemini API key. The environment variable is checked first (this
# is how Space secrets are provided); st.secrets (a local
# .streamlit/secrets.toml) is only a fallback for local testing.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    try:
        api_key = st.secrets["GOOGLE_API_KEY"]
    except Exception:
        # st.secrets raises when there is no secrets file or the key is
        # missing; either way no usable key was found. Catching Exception
        # (not a bare except) keeps SystemExit/KeyboardInterrupt propagating.
        api_key = None
    # Also reject an empty-string secret, not just a missing one.
    if not api_key:
        st.error("Missing GOOGLE_API_KEY. Please add it to your Space Secrets.")
        st.stop()

client = genai.Client(api_key=api_key)
# --- 1. Sidebar: Configuration & Context ---
st.sidebar.title("🛠️ Database Context")

# SQL dialect and (optional) database name; both are substituted into the
# system prompt when a query is generated.
dialect = st.sidebar.selectbox("SQL Dialect", ["PostgreSQL", "MySQL", "SQLite", "BigQuery", "Snowflake"])
db_name = st.sidebar.text_input("Database Name", placeholder="e.g. production_db")

# File uploader for schema DDL.
st.sidebar.subheader("📂 Upload Schema")
uploaded_file = st.sidebar.file_uploader("Upload .sql or .txt file", type=["sql", "txt"])

# An uploaded file pre-fills the schema text area; the user can still edit it.
initial_schema = ""
if uploaded_file is not None:
    try:
        # "utf-8-sig" decodes plain UTF-8 and also strips a leading BOM, which
        # would otherwise end up as a stray character in the prompt.
        initial_schema = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        # Previously an uncaught exception that crashed the whole app.
        st.sidebar.error("Could not read the uploaded file: it is not valid UTF-8 text.")

schema = st.sidebar.text_area(
    "Table Schemas (DDL)",
    value=initial_schema,
    placeholder="CREATE TABLE users (id INT, name TEXT...)",
    height=300,
)

# Toggle whether the model is asked to explain its query.
show_explanation = st.sidebar.checkbox("Show Logic Explanation", value=False)

if st.sidebar.button("🗑️ Clear Chat History"):
    st.session_state.messages = []
    # Also forget the last query so the download button does not keep offering
    # a query that is no longer visible in the (now empty) chat.
    st.session_state.last_query = ""
    st.rerun()
# --- 2. Initialize Gemini Client ---
# The Gemini client is created once, near the top of the script, from
# GOOGLE_API_KEY (environment variable first, st.secrets as a fallback).
# It is intentionally not re-created here: re-reading the key from st.secrets
# alone would ignore the environment variable and stop the app whenever no
# secrets.toml is present, even though a valid key had already been found.
# --- 3. Main UI ---
st.title("🤖 Gemini SQL Generator")
st.caption("Generate, inspect, and download optimized SQL queries.")

# Initialize per-session state on first run: the chat transcript and the most
# recently generated query (read by the download section).
if "messages" not in st.session_state:
    st.session_state.messages = []
if "last_query" not in st.session_state:
    st.session_state.last_query = ""

# Replay the stored transcript; assistant turns are rendered as SQL code
# blocks, user turns as plain text.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        if msg["role"] == "assistant":
            st.code(msg["content"], language="sql")
        else:
            st.write(msg["content"])
# --- 4. Chat Input & Generation ---
if prompt := st.chat_input("Show me all orders from the last 30 days..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    # Build the system instruction from prompts.py.
    # NOTE: prompts.SYSTEM_INSTRUCTION must contain the {dialect}, {db_name},
    # {schema} and {explain} placeholders, or .format() raises KeyError.
    explain_text = "Provide a brief explanation after the code." if show_explanation else "Provide ONLY the code."
    full_system_msg = prompts.SYSTEM_INSTRUCTION.format(
        dialect=dialect,
        # Treat whitespace-only input the same as empty input.
        db_name=db_name.strip() or "unspecified",
        schema=schema if schema.strip() else "No specific schema provided.",
        explain=explain_text,
    )

    sql_output = None
    with st.spinner("Generating SQL..."):
        # Keep the try body to the API call and text extraction only, so
        # unrelated UI/state errors are not reported as generation failures.
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                config={"system_instruction": full_system_msg},
                contents=prompts.USER_PROMPT_TEMPLATE.format(user_input=prompt),
            )
            sql_output = response.text
        except Exception as e:
            st.error(f"Generation failed: {e}")
        else:
            # NOTE(review): response.text can be None when the response has no
            # text parts (e.g. a blocked candidate); don't store or render that.
            if not sql_output:
                st.error("Generation failed: the model returned an empty response.")

    if sql_output:
        st.session_state.last_query = sql_output
        # Add assistant response to state and render it.
        st.session_state.messages.append({"role": "assistant", "content": sql_output})
        with st.chat_message("assistant"):
            st.code(sql_output, language="sql")
# --- 5. Download Action ---
# Offer the most recently generated query as a .sql file once one exists.
if st.session_state.last_query:
    st.divider()
    st.download_button(
        label="💾 Download Latest Query",
        data=st.session_state.last_query,
        file_name="generated_query.sql",
        mime="text/x-sql",
        use_container_width=True,
    )