File size: 7,298 Bytes
fe8a467
e29c767
fe8a467
 
e29c767
fe8a467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b701a18
fe8a467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import os
import streamlit as st
from dotenv import load_dotenv
import pandas as pd

# Local imports
from auth import authenticator
from utils import load_table_config, load_uploaded_files, display_table_descriptions
# from SmartQuery_GC import SmartQuery
from SmartQuery import SmartQuery
# If you use chat_ui.py:
from chat_ui import display_chat

load_dotenv()

# -----------------------------------------------------------------------
# Set page config
# Kept ahead of every other rendering call (older Streamlit versions require
# set_page_config to be the first Streamlit command in the script).
st.set_page_config(
    page_title="MusoLyze",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# -----------------------------------------------------------------------
# Constants
# The shared login token is read from Streamlit secrets, not the environment.
AUTH_TOKEN = st.secrets["AUTH_TOKEN"]
ACCESS_JSON_PATH = "access.json"
TABLE_CONFIG_PATH = "table_config.json"
CSS_PATH = "style.css"

# Inject the app stylesheet. Decode explicitly as UTF-8 so non-ASCII CSS
# (fonts, `content:` strings, emoji) does not depend on the platform locale.
with open(CSS_PATH, "r", encoding="utf-8") as f:
    css_text = f.read()
st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)

# -----------------------------------------------------------------------
# Initialize Session State
# Each key is seeded only on the first run of a browser session; on later
# reruns whatever the app stored under that key is left untouched.
for _state_key, _state_default in (
    ("authenticated", False),
    ("history", []),
    ("dataframes", []),
    ("brand", None),
    # Brand, tables and uploaded file names used by the last "Load Data"
    # click, so a changed selection can be detected.
    ("previous_selection", {
        "brand": None,
        "tables": [],
        "uploaded_files": []
    }),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# -----------------------------------------------------------------------
# LOGIN PAGE
# Unauthenticated visitors only ever see this form: the final st.stop()
# prevents the rest of the script from rendering.
if not st.session_state["authenticated"]:
    # NOTE(review): every st.markdown call renders as a separate element, so
    # this opening <div> and the closing one below probably do not actually
    # wrap the widgets in between — confirm the .login-container CSS applies.
    st.markdown('<div class="login-container">', unsafe_allow_html=True)
    st.markdown("## MusoLyze Login")
    st.write("Please enter your email and authentication token to proceed.")

    email = st.text_input("Email", placeholder="john.doe@example.com")
    token = st.text_input("Token", type="password", placeholder="Enter your token")

    if st.button("Log In"):
        if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
            st.session_state["authenticated"] = True
            # Rerun right away so the main app replaces the login form. The
            # previous st.stop() left the form on screen until the user
            # happened to interact with the page again.
            st.rerun()
        else:
            st.error("Invalid email or token. Please try again.")

    st.markdown('</div>', unsafe_allow_html=True)
    st.stop()  # Stop execution so the rest of the page is not shown.

# -----------------------------------------------------------------------
# Main App: Load Data, Show Chat
st.title("💬 MusoLyze")

# Query helper: reads database tables and answers questions about the data.
sq = SmartQuery()

# Per-table configuration; its keys are the selectable database tables.
table_config = load_table_config(TABLE_CONFIG_PATH)

# --- Sidebar: choose data sources -------------------------------------
sidebar = st.sidebar
sidebar.title("Data Selection")

# Local files (any number of CSV / Excel sheets).
uploaded_files = sidebar.file_uploader(
    "Upload CSV or Excel files",
    type=['csv', 'xlsx', 'xls'],
    accept_multiple_files=True,
)

# Brand whose data the database tables are read for.
brand = sidebar.selectbox(
    "Choose your brand.",
    ["drumeo", "guitareo", "pianote", "singeo"],
)
st.session_state["brand"] = brand

# Database tables to pull in alongside (or instead of) uploaded files.
db_tables = sidebar.multiselect(
    "Select tables from database",
    options=list(table_config.keys()),
    help="Select one or more tables to include in your data.",
)

# Describe whatever tables are currently selected.
display_table_descriptions(db_tables, table_config)

# 'Load Data' button
# Rebuilds st.session_state['dataframes'] from the current uploads and
# selected tables; a changed selection also clears the chat history.
if st.sidebar.button("Load Data"):
    # 1) Snapshot the current selection to compare with the last load.
    new_selection = {
        "brand": brand,
        "tables": db_tables,
        "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else []
    }

    # 2) A different brand/table/file selection invalidates the chat history.
    if new_selection != st.session_state["previous_selection"]:
        st.session_state["history"] = []

    # 3) Proceed with loading data
    dataframes = []
    failed_tables = []

    # Load from uploaded files
    if uploaded_files:
        dataframes.extend(load_uploaded_files(uploaded_files))

    # Load dataframes from selected tables (multiselect always yields a list)
    for table_name in db_tables:
        table_info = table_config[table_name]
        # .get() so a config entry without "source" is reported, not a crash.
        source = table_info.get("source")
        if source not in ("Snowflake", "MySQL"):
            # Previously this case fell through with `df` unbound (NameError)
            # or still bound to the prior table's frame, which was then
            # appended a second time.
            st.error(f"Error loading table {table_name}: unsupported source {source!r}")
            failed_tables.append(table_name)
            continue
        try:
            if source == 'Snowflake':
                session = sq.snowflake_connection()
                df = sq.read_snowflake_table(session, table_name, st.session_state.brand)
            else:  # 'MySQL'
                engine = sq.mysql_connection()
                df = sq.read_mysql_table(engine, table_name, st.session_state.brand)
        except Exception as e:
            st.error(f"Error loading table {table_name}: {e}")
            failed_tables.append(table_name)
        else:
            dataframes.append(df)

    st.session_state['dataframes'] = dataframes

    # 4) Update previous_selection in session state
    st.session_state["previous_selection"] = new_selection

    # Only claim full success when every requested table actually loaded.
    if failed_tables:
        st.warning(f"Data loaded with errors; could not load: {', '.join(failed_tables)}")
    else:
        st.success("Data loaded successfully!")

# --------------------------------------------------------------------------
# If no data is loaded, warn and stop
if not st.session_state['dataframes']:
    st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
    st.stop()

# Always show the first 5 rows of every loaded DataFrame. Each preview is
# numbered so several loaded frames can be told apart.
for idx, df in enumerate(st.session_state['dataframes'], start=1):
    st.markdown(f"**Preview of loaded data #{idx}:**")
    st.dataframe(df.head(5))

# --- Chat Display Section ---
display_chat(st.session_state['history'])

# --- User Input Section ---
st.markdown("---")

with st.form(key="user_query_form"):
    user_query = st.text_input(
        "Ask a question about your data:",
        placeholder="Type your question and press Enter..."
    )
    send_button = st.form_submit_button("Send")

if send_button and user_query.strip():
    query_succeeded = False
    with st.spinner("Analyzing your data..."):
        try:
            response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])

            # "dataframe" and "plot" results are kept as-is; anything else is
            # stored as a plain text ("string") entry.
            response_type = response['type']
            entry_type = response_type if response_type in ("dataframe", "plot") else "string"
            st.session_state['history'].append({
                'user': user_query,
                'type': entry_type,
                'bot': response['value'],
            })
            query_succeeded = True
        except Exception as e:
            st.error(f"Error: {e}")

    # Rerun outside the broad try/except so Streamlit's rerun control
    # exception can never be caught by `except Exception` (whether it derives
    # from Exception has varied between Streamlit versions).
    if query_succeeded:
        # Rerun to refresh page and clear input
        st.rerun()

elif send_button and not user_query.strip():
    st.warning("Please enter a question before sending.")