# Scraped from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import os
import re
from io import StringIO

import streamlit as st
from google import genai

import prompts  # Assuming prompts.py is in the same directory
st.set_page_config(page_title="SQL AI Assistant", layout="wide")

# --- 1. Sidebar: Configuration & Context ---
# Everything collected here is interpolated into the system prompt in section 4.
st.sidebar.title("🛠️ Database Context")

# Target dialect and (optional) database name.
dialect = st.sidebar.selectbox(
    "SQL Dialect", ["PostgreSQL", "MySQL", "SQLite", "BigQuery", "Snowflake"]
)
db_name = st.sidebar.text_input("Database Name", placeholder="e.g. production_db")

# Optional schema file; its text pre-fills the DDL box below, where it can still be edited.
st.sidebar.subheader("📂 Upload Schema")
uploaded_file = st.sidebar.file_uploader("Upload .sql or .txt file", type=["sql", "txt"])

initial_schema = ""
if uploaded_file is not None:
    try:
        # "utf-8-sig" decodes exactly like UTF-8 but also drops a leading
        # byte-order mark, which would otherwise end up at the start of the DDL.
        initial_schema = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        # Report the problem instead of crashing the whole app on a bad upload.
        st.sidebar.error("Could not read the uploaded file: it is not valid UTF-8 text.")

schema = st.sidebar.text_area(
    "Table Schemas (DDL)",
    value=initial_schema,
    placeholder="CREATE TABLE users (id INT, name TEXT...)",
    height=300,
)

# When checked, the model is asked to append a short explanation after the code.
show_explanation = st.sidebar.checkbox("Show Logic Explanation", value=False)

if st.sidebar.button("🗑️ Clear Chat History"):
    st.session_state.messages = []
    # Also forget the last query so the download button does not keep offering
    # a result from the conversation that was just cleared.
    st.session_state.last_query = ""
    st.rerun()
# --- 2. Initialize Gemini Client ---
# The API key is taken from Streamlit secrets, falling back to the
# GOOGLE_API_KEY environment variable when no such secret is available.
try:
    api_key = st.secrets["GOOGLE_API_KEY"]
except Exception:
    # st.secrets raises KeyError for a missing key and may raise a different
    # error when no secrets file is configured at all; both mean "not in secrets".
    api_key = os.environ.get("GOOGLE_API_KEY")

if not api_key:
    st.error("API Key not found. Please ensure GOOGLE_API_KEY is set in Hugging Face Secrets.")
    st.stop()

try:
    client = genai.Client(api_key=api_key)
except Exception as e:
    # A key that exists but cannot be used is a different problem from a
    # missing key, so surface the real cause rather than "not found".
    st.error(f"Could not initialise the Gemini client: {e}")
    st.stop()
# --- 3. Main UI ---
st.title("🤖 Gemini SQL Generator")
st.caption("Generate, inspect, and download optimized SQL queries.")

# Session state persists across Streamlit reruns: the chat transcript and the
# most recent query offered for download in section 5.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "last_query" not in st.session_state:
    st.session_state.last_query = ""

# Replay the transcript on every rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        # Guard against an empty/None reply having been stored.
        content = msg["content"] or ""
        if msg["role"] != "assistant":
            st.write(content)
        elif "```" in content:
            # The reply already contains Markdown code fences (and possibly an
            # explanation); let Markdown render them instead of showing the
            # fences literally inside another SQL code box.
            st.markdown(content)
        else:
            st.code(content, language="sql")
# --- 4. Chat Input & Generation ---
def _extract_sql(text: str) -> str:
    """Return the query text to offer for download from a model reply.

    If *text* contains a Markdown code fence (e.g. ```sql ... ```), the contents
    of the first fence are returned; an unterminated fence runs to the end of
    the text. Without any fence the whole reply is returned. Leading and
    trailing whitespace is stripped in both cases.
    """
    match = re.search(r"```[^\n]*\n(.*?)(?:```|\Z)", text, re.DOTALL)
    return (match.group(1) if match else text).strip()


if prompt := st.chat_input("Show me all orders from the last 30 days..."):
    # Record the request in the transcript and echo it immediately.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    # prompts.SYSTEM_INSTRUCTION must contain the {dialect}, {db_name}, {schema}
    # and {explain} placeholders; prompts.USER_PROMPT_TEMPLATE needs {user_input}.
    explain_text = (
        "Provide a brief explanation after the code."
        if show_explanation
        else "Provide ONLY the code."
    )
    full_system_msg = prompts.SYSTEM_INSTRUCTION.format(
        dialect=dialect,
        db_name=db_name if db_name else "unspecified",
        schema=schema if schema else "No specific schema provided.",
        explain=explain_text,
    )

    # Only the current request is sent: earlier turns of the transcript are not
    # part of `contents`, so each generation is independent.
    reply = None
    with st.spinner("Generating SQL..."):
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                config={"system_instruction": full_system_msg},
                contents=prompts.USER_PROMPT_TEMPLATE.format(user_input=prompt),
            )
            reply = response.text
        except Exception as e:  # UI boundary: show any API/template failure to the user
            st.error(f"Generation failed: {e}")
        else:
            if not reply:
                # `.text` may be None/empty when the response has no text parts
                # (for example if it was blocked); don't store or render that.
                st.warning("The model returned no text. Try rephrasing your request.")

    if reply:
        st.session_state.messages.append({"role": "assistant", "content": reply})
        # Download only the SQL itself, without Markdown fences or explanation prose.
        st.session_state.last_query = _extract_sql(reply)
        with st.chat_message("assistant"):
            if "```" in reply:
                st.markdown(reply)
            else:
                st.code(reply, language="sql")
# --- 5. Download Action ---
# Offer the most recently generated query (if any) as a .sql file.
if st.session_state.last_query:
    st.divider()
    st.download_button(
        label="💾 Download Latest Query",
        data=st.session_state.last_query,
        file_name="generated_query.sql",
        mime="text/x-sql",
        use_container_width=True,
    )