# Text_to_sql / UI.py
# Source: commit 65995bf ("Upload UI.py") by bhavika24
from datetime import datetime, timezone

import streamlit as st

from engine import run
# =========================
# PAGE CONFIG
# =========================
# Configure the browser tab and layout, then render the page header.
st.set_page_config(page_title="Text-to-SQL AI", layout="wide")

st.title("🧠 Text-to-SQL Assistant")
st.caption(
    "Ask questions in natural language. "
    "I’ll generate SQL using your database metadata."
)
# =========================
# SESSION STATE
# =========================
# Seed the per-session containers only when absent, so the chat history and
# the query transcript persist across script reruns within one session.
for _state_key in ("messages", "transcript"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []
# =========================
# CHAT HISTORY
# =========================
# Replay the stored conversation (kept in session_state so it survives
# reruns). Content is rendered as plain markdown: it holds raw user input and
# engine output, neither of which should be interpreted as HTML. This also
# matches how the user message is rendered the first time it is shown.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# =========================
# USER INPUT
# =========================
def _build_reply(result):
    """Render an engine result dict as the assistant's markdown reply.

    Reads ``status`` ("ok" means success) and the optional ``message`` and
    ``sql`` keys, which is the shape this script builds itself on failure.
    NOTE(review): assumes ``engine.run`` returns the same shape — confirm.

    Returns:
        Markdown text: a result section and/or a fenced SQL block on success,
        or an error line otherwise.
    """
    if result.get("status") != "ok":
        # Never show "Error: None" when the engine gives no message.
        return f"❌ **Error:** {result.get('message') or 'Unknown error'}"
    parts = []
    if result.get("message"):
        parts.append(f"### ✅ Result\n{result['message']}\n\n")
    if result.get("sql"):
        parts.append("### 🧾 Generated SQL\n")
        parts.append(f"```sql\n{result['sql']}\n```")
    # An "ok" result with neither message nor SQL would otherwise render as
    # an empty chat bubble.
    return "".join(parts) or "ℹ️ The engine returned no message or SQL."


user_input = st.chat_input("Ask something like: 'Show employees in IT department'")
if user_input:
    # Display the user message now and keep it for replay on later reruns.
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Call the engine; any exception becomes an error result so it is shown
    # in the chat rather than aborting the script.
    with st.spinner("Generating SQL..."):
        try:
            result = run(user_input)
        except Exception as e:
            result = {"status": "error", "message": str(e)}
    if not isinstance(result, dict):
        # Guard against a None / non-dict return, which would otherwise crash
        # on ``result.get`` below.
        result = {
            "status": "error",
            "message": f"Unexpected engine response: {result!r}",
        }

    reply = _build_reply(result)
    succeeded = result.get("status") == "ok"

    # Save transcript. Timestamps are timezone-aware UTC: datetime.utcnow()
    # is deprecated since Python 3.12 and yields a naive value.
    st.session_state.transcript.append({
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "question": user_input,
        "reply": reply,
        "sql": result.get("sql"),
        "error": None if succeeded else result.get("message"),
    })

    # Show the assistant reply as plain markdown; it embeds engine output
    # (including raw exception text) that must not be interpreted as HTML.
    st.session_state.messages.append({"role": "assistant", "content": reply})
    with st.chat_message("assistant"):
        st.markdown(reply)
# =========================
# DOWNLOAD LOG
# =========================
def download_transcript(transcript=None):
    """Format the query transcript as plain text for the download button.

    Each entry becomes a ``--- Query N ---`` section (numbered from 1) with its
    timestamp, question and reply. The entry's SQL follows when it is truthy.

    Args:
        transcript: Sequence of dicts with ``timestamp``, ``question``,
            ``reply`` and optional ``sql`` keys. Defaults to
            ``st.session_state.transcript``, so existing no-argument callers
            behave exactly as before.

    Returns:
        The formatted log text; an empty string for an empty transcript.
    """
    if transcript is None:
        transcript = st.session_state.transcript
    lines = []
    for i, entry in enumerate(transcript, 1):
        lines.append(f"\n--- Query {i} ---")
        lines.append(f"Time: {entry['timestamp']}")
        lines.append(f"Question: {entry['question']}")
        lines.append(f"Reply:\n{entry['reply']}")
        if entry.get("sql"):
            lines.append(f"SQL:\n{entry['sql']}")
    return "\n".join(lines)
# Offer the log for download only once at least one query has been recorded.
has_history = bool(st.session_state.transcript)
if has_history:
    st.divider()
    log_text = download_transcript()
    st.download_button(
        "⬇️ Download Query Log",
        data=log_text,
        file_name="text_to_sql_log.txt",
        mime="text/plain",
        use_container_width=True,
    )