File size: 6,795 Bytes
8642c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import streamlit as st
import pandas as pd
import os
import sys
from dotenv import load_dotenv

# Add src to path
sys.path.append(os.getcwd())

from src.rag_manager import RAGManager
from src.sql_generator import SQLGenerator
from src.db_connector import DatabaseConnector

# --- 1. CONFIGURATION ---
# Browser-tab title and icon, wide page layout, and a sidebar that starts
# collapsed.
_PAGE_CONFIG = dict(
    page_title="NexusAI | Enterprise Data",
    page_icon="✨",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_CONFIG)

# Custom CSS injected once per run. It covers:
#   * the Inter web font and a dark app background;
#   * hiding Streamlit's main menu, footer and header;
#   * chat bubbles, where odd-numbered chat messages are laid out right-aligned
#     (user style) and even-numbered ones left-aligned. NOTE(review): this
#     appears to assume user/assistant turns strictly alternate; confirm;
#   * a fixed chat input at the bottom center;
#   * dark panels for dataframes and the sidebar.
_CUSTOM_CSS = """

<style>

    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

    html, body, [class*="css"] { font-family: 'Inter', sans-serif; }

    .stApp { background-color: #0F1117; }

    #MainMenu, footer, header { visibility: hidden; }

    

    .stChatMessage { background-color: transparent !important; border: none !important; }

    

    div[data-testid="stChatMessage"]:nth-child(odd) { flex-direction: row-reverse; }

    div[data-testid="stChatMessage"]:nth-child(odd) .stMarkdown {

        background-color: #2B2D31; color: #E0E0E0;

        border-radius: 18px 18px 4px 18px; padding: 12px 20px;

        text-align: right; margin-left: auto;

    }

    

    div[data-testid="stChatMessage"]:nth-child(even) .stMarkdown {

        background-color: transparent; color: #F0F0F0; padding-left: 10px;

    }

    

    .stChatInput { position: fixed; bottom: 30px; width: 70% !important; left: 50%; transform: translateX(-50%); z-index: 1000; }

    .stTextInput > div > div > input { background-color: #1E2128; color: white; border-radius: 24px; border: 1px solid #363B47; }

    

    div[data-testid="stDataFrame"] { background-color: #161920; border-radius: 10px; padding: 10px; border: 1px solid #30363D; }

    section[data-testid="stSidebar"] { background-color: #0E1015; border-right: 1px solid #222; }

</style>

"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# --- 2. INITIALIZATION ---
@st.cache_resource
def get_core():
    """Create the app's core service objects.

    The function is wrapped in ``st.cache_resource``, so the objects are
    built once and then reused on later script reruns instead of being
    reconstructed each time.

    Returns:
        tuple: ``(RAGManager(), SQLGenerator(api_key), DatabaseConnector())``,
        built in that order.
    """
    # Load variables from a .env file (python-dotenv) before reading the key.
    load_dotenv()
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    # NOTE(review): gemini_api_key is None when GEMINI_API_KEY is unset;
    # confirm SQLGenerator reports that clearly.
    retriever = RAGManager()
    generator = SQLGenerator(gemini_api_key)
    connector = DatabaseConnector()
    return retriever, generator, connector

# Obtain the (cached) core services. On any failure, show the error and halt
# this script run so nothing below executes without rag/sql_gen/db.
try:
    rag, sql_gen, db = get_core()
except Exception as init_error:
    st.error(f"System Offline: {init_error}")
    st.stop()

# --- 3. SIDEBAR ---
# Branding, connection badge, one-click example prompts and a reset button.
with st.sidebar:
    st.markdown("## 🧠 NexusAI")
    st.caption("Enterprise SQL Agent v2.0")
    st.divider()

    # NOTE(review): `db` is the DatabaseConnector returned by get_core(). Unless
    # that class defines __bool__/__len__, this check is always true, so the
    # badge is shown without actually probing the database — confirm.
    if db:
        st.success("🟢 Database Connected")

    st.markdown("### 📚 Quick Prompts")
    prompts = [
        "Top 5 employees by salary",
        "Total sales revenue by Region",
        "Show me products with low stock",
        "Which department spends the most?"
    ]

    # A click only queues the prompt in session state; the chat-input handling
    # later in the script reads `last_prompt` during this same run.
    for p in prompts:
        if st.button(p, use_container_width=True):
            st.session_state.last_prompt = p

    # Drop the whole conversation and redraw immediately.
    if st.button("🗑️ Clear Context", type="primary", use_container_width=True):
        st.session_state.messages = []
        st.rerun()

# --- 4. MAIN INTERFACE ---
# Conversation history lives in session state as a list of dicts with
# "role"/"content" and, for assistant answers, optional "data", "chart", "sql".
if "messages" not in st.session_state:
    st.session_state.messages = []

# Empty-state banner, shown only before the first message exists.
# NOTE(review): the blank lines inside this HTML may make the Markdown parser
# end the <div> HTML block early and show the indented lines as a code block;
# verify it renders as intended.
if not st.session_state.messages:
    st.markdown("""

    <div style="text-align: center; margin-top: 100px;">

        <h1 style="font-size: 3rem; background: -webkit-linear-gradient(#eee, #333); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">

            What can I help you analyze?

        </h1>

        <p style="color: #666;">Connect to your database and ask questions in plain English.</p>

    </div>

    """, unsafe_allow_html=True)

# Replay the stored conversation on every rerun.
for msg in st.session_state.messages:
    # The avatar must be a real emoji: Streamlit tries to load any other string
    # as an image and fails.
    with st.chat_message(msg["role"], avatar="👤" if msg["role"] == "user" else "✨"):
        st.markdown(msg["content"])

        if "data" in msg:
            st.dataframe(msg["data"], hide_index=True)
        # Entries may carry chart=None (no chartable columns); skip those
        # instead of drawing an empty chart. Styled like the live render.
        if msg.get("chart") is not None:
            st.markdown("##### 📊 Trends")
            st.bar_chart(msg["chart"], color="#7B61FF")
        if "sql" in msg:
            with st.expander("🛠️ View Query Logic"):
                st.code(msg["sql"], language="sql")

# Handle Input: the typed chat message is used unless a sidebar quick prompt
# was queued in this run, in which case the queued prompt replaces it and is
# cleared so it is consumed only once.
user_input = st.chat_input("Ask anything...")

queued_prompt = st.session_state.get("last_prompt")
if queued_prompt:
    user_input = queued_prompt
    st.session_state.last_prompt = None

def _build_chart_data(frame):
    """Return a single series suitable for ``st.bar_chart``, or None.

    The first numeric column supplies the bar values. When the frame also has
    a non-numeric column, the first such column becomes the index so it labels
    the x-axis. Frames with no numeric column, or with fewer than two rows,
    get no chart (None).
    """
    numeric_cols = frame.select_dtypes(include=["number"]).columns
    if numeric_cols.empty or len(frame) <= 1:
        return None
    value_col = numeric_cols[0]
    label_cols = frame.select_dtypes(exclude=["number"]).columns
    if label_cols.empty:
        return frame[value_col]
    return frame.set_index(label_cols[0])[value_col]


if user_input:
    # Record and echo the user's turn right away.
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user", avatar="👤"):
        st.markdown(user_input)

    with st.chat_message("assistant", avatar="✨"):
        # Placeholder showing progress until the pipeline finishes.
        status_box = st.empty()
        status_box.markdown("`⚡ analyzing...`")

        try:
            # Ask the RAG manager for tables relevant to the question (joined
            # with newlines, so presumably strings) and pass them as context.
            tables = rag.get_relevant_tables(user_input)
            context = "\n".join(tables)

            sql = sql_gen.generate_sql(user_input, context)

            # NOTE(review): this runs model-generated SQL verbatim. Ensure the
            # DatabaseConnector uses a read-only role or rejects writes.
            results = db.execute_sql(sql)
            status_box.empty()

            if not results:
                response = "No data found matching that request."
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response, "sql": sql})
            else:
                df = pd.DataFrame(results)
                df_clean = df.reset_index(drop=True)

                response = f"Found **{len(df)}** records."
                st.markdown(response)
                st.dataframe(df_clean, hide_index=True)

                # Charting is best-effort: a failure must not hide the answer.
                chart_data = None
                try:
                    chart_data = _build_chart_data(df_clean)
                    if chart_data is not None:
                        st.markdown("##### 📊 Trends")
                        st.bar_chart(chart_data, color="#7B61FF")
                except Exception:
                    # Don't persist a chart that failed to render: the history
                    # replay would try (and fail) to draw it on every rerun.
                    chart_data = None

                assistant_msg = {
                    "role": "assistant",
                    "content": response,
                    "data": df_clean,
                    "sql": sql,
                }
                # Store a chart only when one was drawn, so the replay doesn't
                # render empty charts for chart-less results.
                if chart_data is not None:
                    assistant_msg["chart"] = chart_data
                st.session_state.messages.append(assistant_msg)

        except Exception as e:
            status_box.empty()
            error_text = f"Error: {e}"
            st.error(error_text)
            # Persist the failure as the assistant's turn so the error survives
            # reruns and the history keeps alternating user/assistant messages.
            st.session_state.messages.append({"role": "assistant", "content": error_text})