"""Streamlit chat front end for the Text-to-SQL assistant.

Each question typed into the chat box is passed to ``engine.run``. The
formatted reply is shown in the chat and stored in ``st.session_state``
(``messages`` for the on-screen history, ``transcript`` for the downloadable
log).
"""
from datetime import datetime, timezone

import streamlit as st

from engine import run


def _build_reply(result):
    """Format an engine result as the Markdown shown in the assistant bubble.

    ``result`` is read through ``.get`` for the keys ``status``, ``message``
    and ``sql``. Any status other than ``"ok"`` is reported as an error.
    """
    if result.get("status") != "ok":
        # Avoid rendering the literal text "None" when no message was supplied.
        return f"❌ **Error:** {result.get('message') or 'Unknown error'}"

    reply = ""
    if result.get("message"):
        reply += f"### ✅ Result\n{result['message']}\n\n"
    if result.get("sql"):
        reply += "### 🧾 Generated SQL\n"
        reply += f"```sql\n{result['sql']}\n```"

    # An "ok" result with neither message nor SQL would otherwise be rendered
    # as an empty chat bubble.
    return reply or "⚠️ The engine returned no message or SQL for this question."


# =========================
# PAGE CONFIG
# =========================
st.set_page_config(
    page_title="Text-to-SQL AI",
    layout="wide",
)

st.title("🧠 Text-to-SQL Assistant")
st.caption("Ask questions in natural language. I’ll generate SQL using your database metadata.")

# =========================
# SESSION STATE
# =========================
if "messages" not in st.session_state:
    st.session_state.messages = []

if "transcript" not in st.session_state:
    st.session_state.transcript = []

# =========================
# CHAT HISTORY
# =========================
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        # User text is untrusted and is shown as plain Markdown when it is
        # first submitted, so replay it the same way instead of enabling raw
        # HTML for it.
        # NOTE(review): assistant replies keep raw HTML enabled as before, but
        # they can embed engine/exception text -- confirm HTML is really needed.
        st.markdown(
            msg["content"],
            unsafe_allow_html=msg["role"] == "assistant",
        )

# =========================
# USER INPUT
# =========================
user_input = st.chat_input("Ask something like: 'Show employees in IT department'")

if user_input:
    # Display user message
    st.session_state.messages.append(
        {"role": "user", "content": user_input}
    )
    with st.chat_message("user"):
        st.markdown(user_input)

    # Call engine; any failure is turned into an error result so the chat
    # shows the problem instead of the script aborting.
    with st.spinner("Generating SQL..."):
        try:
            result = run(user_input)
        except Exception as e:
            result = {
                "status": "error",
                "message": str(e),
            }

    # =========================
    # BUILD RESPONSE
    # =========================
    reply = _build_reply(result)
    failed = result.get("status") != "ok"

    # Save transcript
    st.session_state.transcript.append({
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated and
        # returns a naive value.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "question": user_input,
        "reply": reply,
        "sql": result.get("sql"),
        "error": result.get("message") if failed else None,
    })

    # Show assistant reply
    st.session_state.messages.append(
        {"role": "assistant", "content": reply}
    )
    with st.chat_message("assistant"):
        st.markdown(reply, unsafe_allow_html=True)
# =========================
# DOWNLOAD LOG
# =========================
def download_transcript():
    """Return the session transcript as one plain-text log string.

    Every entry in ``st.session_state.transcript`` contributes a numbered
    header followed by its timestamp, question and reply; the SQL is appended
    as an extra section only when the entry has a non-empty ``sql`` value.
    """
    log_lines = []
    for number, entry in enumerate(st.session_state.transcript, start=1):
        log_lines += [
            f"\n--- Query {number} ---",
            f"Time: {entry['timestamp']}",
            f"Question: {entry['question']}",
            f"Reply:\n{entry['reply']}",
        ]
        sql_text = entry.get("sql")
        if sql_text:
            log_lines.append(f"SQL:\n{sql_text}")
    return "\n".join(log_lines)


# Offer the download only once at least one query has been recorded.
if st.session_state.transcript:
    st.divider()
    log_text = download_transcript()
    st.download_button(
        "⬇️ Download Query Log",
        data=log_text,
        file_name="text_to_sql_log.txt",
        mime="text/plain",
        use_container_width=True,
    )