import streamlit as st
import pandas as pd
import os
import sys
from dotenv import load_dotenv
# Add the current working directory to sys.path so the `src` package below is importable
sys.path.append(os.getcwd())
from src.rag_manager import RAGManager
from src.sql_generator import SQLGenerator
from src.db_connector import DatabaseConnector
# --- 1. CONFIGURATION ---
# Browser-tab title/icon, wide layout, and a sidebar that starts collapsed.
st.set_page_config(
    page_title="NexusAI | Enterprise Data",
    page_icon="✨",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Custom CSS
# NOTE(review): this markup payload is currently empty, so the call renders
# nothing visible — restore the intended <style> rules or drop the call.
CUSTOM_CSS = """
"""
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# --- 2. INITIALIZATION ---
@st.cache_resource
def get_core():
    """Construct the app's backend services.

    Loads variables from a .env file, then builds the RAG manager, the SQL
    generator (given the GEMINI_API_KEY value) and the database connector, in
    that order. Decorated with ``st.cache_resource`` so the instances are reused
    across reruns instead of being rebuilt on every interaction.

    Returns:
        tuple: ``(RAGManager, SQLGenerator, DatabaseConnector)``.
    """
    load_dotenv()
    # NOTE(review): if GEMINI_API_KEY is unset this is None and is passed
    # through unchanged — confirm SQLGenerator handles that, or fail fast here.
    gemini_key = os.getenv("GEMINI_API_KEY")
    rag_manager = RAGManager()
    generator = SQLGenerator(gemini_key)
    connector = DatabaseConnector()
    return rag_manager, generator, connector


try:
    rag, sql_gen, db = get_core()
except Exception as exc:
    # Nothing below can work without these services: report and halt this run.
    st.error(f"System Offline: {exc}")
    st.stop()
# --- 3. SIDEBAR ---
# Canned questions offered as one-click buttons in the sidebar.
QUICK_PROMPTS = (
    "Top 5 employees by salary",
    "Total sales revenue by Region",
    "Show me products with low stock",
    "Which department spends the most?",
)

with st.sidebar:
    st.markdown("## 🧠 NexusAI")
    st.caption("Enterprise SQL Agent v2.0")
    st.divider()

    # NOTE(review): `db` is whatever get_core() returned; unless
    # DatabaseConnector defines __bool__/__len__ this is always truthy, so the
    # badge does not reflect a verified live connection.
    if db:
        st.success("🟢 Database Connected")

    st.markdown("### 📚 Quick Prompts")
    for prompt_text in QUICK_PROMPTS:
        clicked = st.button(prompt_text, use_container_width=True)
        if clicked:
            # Stash the choice; the chat-input handling below picks it up.
            st.session_state.last_prompt = prompt_text

    clear_clicked = st.button("🗑️ Clear Context", type="primary", use_container_width=True)
    if clear_clicked:
        st.session_state.messages = []
        st.rerun()
# --- 4. MAIN INTERFACE ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# Welcome text, shown only while the conversation is empty.
if not st.session_state.messages:
    st.markdown("""
What can I help you analyze?
Connect to your database and ask questions in plain English.
""", unsafe_allow_html=True)

# Replay the stored conversation. Streamlit re-executes this script on every
# interaction, so earlier turns must be redrawn from session state.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"], avatar="👤" if msg["role"] == "user" else "✨"):
        st.markdown(msg["content"])
        if "data" in msg:
            st.dataframe(msg["data"], hide_index=True)
        # Assistant entries store "chart": None when nothing was plotted;
        # st.bar_chart(None) would draw an empty chart, so skip those.
        # Header and color match the live render so replays look the same.
        if msg.get("chart") is not None:
            st.markdown("##### 📊 Trends")
            st.bar_chart(msg["chart"], color="#7B61FF")
        if msg.get("sql"):
            with st.expander("🛠️ View Query Logic"):
                st.code(msg["sql"], language="sql")
def _build_chart_data(frame):
    """Pick a single numeric series to plot from a query result.

    The first numeric column supplies the bar heights. When the frame also has
    a non-numeric column, the first such column becomes the category index.

    Args:
        frame: Query result as a pandas DataFrame.

    Returns:
        A pandas Series to pass to ``st.bar_chart``, or None when the frame has
        no numeric column or fewer than two rows.
    """
    numeric_cols = frame.select_dtypes(include=["number"]).columns
    if numeric_cols.empty or len(frame) <= 1:
        return None
    y_axis = numeric_cols[0]
    label_cols = frame.select_dtypes(exclude=["number"]).columns
    if label_cols.empty:
        return frame[y_axis]
    return frame.set_index(label_cols[0])[y_axis]


# Handle Input
user_input = st.chat_input("Ask anything...")
# A sidebar quick-prompt click (stored earlier in this run) takes precedence
# over typed input; clear it so it is consumed exactly once.
queued_prompt = st.session_state.get("last_prompt")
if queued_prompt:
    user_input = queued_prompt
    st.session_state.last_prompt = None

if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user", avatar="👤"):
        st.markdown(user_input)

    with st.chat_message("assistant", avatar="✨"):
        status_box = st.empty()
        status_box.markdown("`⚡ analyzing...`")
        sql = None  # stays None if the failure happens before SQL generation
        try:
            tables = rag.get_relevant_tables(user_input)
            context = "\n".join(tables)
            # NOTE(review): model-generated SQL is executed as-is. Confirm that
            # DatabaseConnector uses read-only credentials and/or rejects
            # non-SELECT statements; nothing in this file enforces that.
            sql = sql_gen.generate_sql(user_input, context)
            results = db.execute_sql(sql)
            status_box.empty()

            if not results:
                response = "No data found matching that request."
                st.markdown(response)
                # Show the query live, matching what the history replay shows.
                with st.expander("🛠️ View Query Logic"):
                    st.code(sql, language="sql")
                st.session_state.messages.append({"role": "assistant", "content": response, "sql": sql})
            else:
                df_clean = pd.DataFrame(results).reset_index(drop=True)
                response = f"Found **{len(df_clean)}** records."
                st.markdown(response)
                st.dataframe(df_clean, hide_index=True)

                # Charting is best-effort. Render into a placeholder so a
                # failure can be wiped completely (no orphaned header). Also
                # don't persist a chart that failed to plot, because the
                # history replay draws stored charts without a try/except.
                chart_data = None
                chart_slot = st.empty()
                try:
                    chart_data = _build_chart_data(df_clean)
                    if chart_data is not None:
                        with chart_slot.container():
                            st.markdown("##### 📊 Trends")
                            st.bar_chart(chart_data, color="#7B61FF")
                except Exception:
                    chart_slot.empty()
                    chart_data = None

                with st.expander("🛠️ View Query Logic"):
                    st.code(sql, language="sql")

                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response,
                    "data": df_clean,
                    "chart": chart_data,
                    "sql": sql,
                })
        except Exception as e:
            status_box.empty()
            st.error(f"Error: {e}")
            # Persist the failure too, so the user's question is not left
            # without a reply once the script reruns.
            error_entry = {"role": "assistant", "content": f"Error: {e}"}
            if sql is not None:
                error_entry["sql"] = sql
            st.session_state.messages.append(error_entry)