"""Streamlit chat front-end that turns plain-English questions into SQL.

Flow per question: look up relevant tables via ``RAGManager``, ask
``SQLGenerator`` for a query, run it through ``DatabaseConnector`` and render
the rows (plus a best-effort bar chart) in a chat transcript kept in
``st.session_state.messages``.
"""
import os
import sys

import pandas as pd
import streamlit as st
from dotenv import load_dotenv

# Add the current working directory to the import path so the `src` package
# resolves when the app is launched from the project root.
sys.path.append(os.getcwd())

from src.rag_manager import RAGManager
from src.sql_generator import SQLGenerator
from src.db_connector import DatabaseConnector

# --- 1. CONFIGURATION ---
st.set_page_config(
    page_title="NexusAI | Enterprise Data",
    page_icon="✨",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Custom CSS hook (the injected markup is currently blank).
st.markdown(""" """, unsafe_allow_html=True)

# Bar colour shared by live rendering and history replay.
CHART_COLOR = "#7B61FF"


# --- 2. INITIALIZATION ---
@st.cache_resource
def get_core():
    """Build the long-lived service objects once per Streamlit process.

    Loads environment variables from a ``.env`` file, reads
    ``GEMINI_API_KEY`` and returns ``(RAGManager(), SQLGenerator(key),
    DatabaseConnector())``.

    NOTE(review): ``key`` is ``None`` when the variable is unset; whether
    ``SQLGenerator`` tolerates that is not visible here -- confirm.
    """
    load_dotenv()
    key = os.getenv("GEMINI_API_KEY")
    return RAGManager(), SQLGenerator(key), DatabaseConnector()


def _select_chart_series(frame: pd.DataFrame) -> "pd.Series | None":
    """Pick the series to plot for a result frame, or ``None`` if not chartable.

    A chart is only produced when the frame has at least one numeric column
    and more than one row. The first numeric column is plotted; when a
    non-numeric column exists, the first one becomes the index (x axis).
    """
    numeric_cols = frame.select_dtypes(include=["number"]).columns
    if numeric_cols.empty or len(frame) <= 1:
        return None

    non_numeric = frame.select_dtypes(exclude=["number"]).columns
    if not non_numeric.empty:
        return frame.set_index(non_numeric[0])[numeric_cols[0]]
    return frame[numeric_cols[0]]


def _render_chart(series) -> None:
    """Draw the "Trends" heading and bar chart used by both live and replay paths."""
    st.markdown("##### 📊 Trends")
    st.bar_chart(series, color=CHART_COLOR)


try:
    rag, sql_gen, db = get_core()
except Exception as e:
    # Top-level boundary: nothing below can work without the core services.
    st.error(f"System Offline: {e}")
    st.stop()

# --- 3. SIDEBAR ---
with st.sidebar:
    st.markdown("## 🧠 NexusAI")
    st.caption("Enterprise SQL Agent v2.0")
    st.divider()

    if db:
        st.success("🟢 Database Connected")

    st.markdown("### 📚 Quick Prompts")
    prompts = [
        "Top 5 employees by salary",
        "Total sales revenue by Region",
        "Show me products with low stock",
        "Which department spends the most?",
    ]
    for p in prompts:
        if st.button(p, use_container_width=True):
            # Queued here, consumed by the input handler further down.
            st.session_state.last_prompt = p

    if st.button("🗑️ Clear Context", type="primary", use_container_width=True):
        st.session_state.messages = []
        st.rerun()

# --- 4. MAIN INTERFACE ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# Welcome banner, shown only while the transcript is empty.
if not st.session_state.messages:
    st.markdown("""

What can I help you analyze?

Connect to your database and ask questions in plain English.

""", unsafe_allow_html=True)

# Replay the stored transcript on every rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"], avatar="👤" if msg["role"] == "user" else "✨"):
        st.markdown(msg["content"])
        if "data" in msg:
            st.dataframe(msg["data"], hide_index=True)
        # The "chart" key is stored even when no chart was produced (value
        # None), so test the value rather than mere key presence.
        if msg.get("chart") is not None:
            _render_chart(msg["chart"])
        if "sql" in msg:
            with st.expander("🛠️ View Query Logic"):
                st.code(msg["sql"], language="sql")

# Handle Input
user_input = st.chat_input("Ask anything...")

# A quick prompt clicked in the sidebar takes precedence over typed input.
if "last_prompt" in st.session_state and st.session_state.last_prompt:
    user_input = st.session_state.last_prompt
    st.session_state.last_prompt = None

if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user", avatar="👤"):
        st.markdown(user_input)

    with st.chat_message("assistant", avatar="✨"):
        status_box = st.empty()
        status_box.markdown("`⚡ analyzing...`")

        try:
            tables = rag.get_relevant_tables(user_input)
            context = "\n".join(tables)
            sql = sql_gen.generate_sql(user_input, context)
            results = db.execute_sql(sql)

            status_box.empty()

            if not results:
                response = "No data found matching that request."
                st.markdown(response)
                st.session_state.messages.append(
                    {"role": "assistant", "content": response, "sql": sql}
                )
            else:
                df = pd.DataFrame(results)
                df_clean = df.reset_index(drop=True)

                response = f"Found **{len(df)}** records."
                st.markdown(response)
                st.dataframe(df_clean, hide_index=True)

                # Charting is best-effort: a failure must not hide the table,
                # and a chart that failed to render must not be persisted,
                # otherwise history replay would raise on every rerun.
                chart_data = None
                try:
                    chart_data = _select_chart_series(df_clean)
                    if chart_data is not None:
                        _render_chart(chart_data)
                except Exception as chart_err:
                    chart_data = None
                    st.caption(f"Chart unavailable: {chart_err}")

                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response,
                    "data": df_clean,
                    "chart": chart_data,
                    "sql": sql,
                })

        except Exception as e:
            # Boundary for the whole question pipeline (retrieval, SQL
            # generation, execution): surface the failure to the user.
            status_box.empty()
            st.error(f"Error: {e}")