Spaces:
Sleeping
Sleeping
File size: 7,298 Bytes
fe8a467 e29c767 fe8a467 e29c767 fe8a467 b701a18 fe8a467 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 | import os
import streamlit as st
from dotenv import load_dotenv
import pandas as pd
# Local imports
from auth import authenticator
from utils import load_table_config, load_uploaded_files, display_table_descriptions
# from SmartQuery_GC import SmartQuery
from SmartQuery import SmartQuery
# If you use chat_ui.py:
from chat_ui import display_chat
# Load environment variables from a local .env file (python-dotenv).
load_dotenv()

# -----------------------------------------------------------------------
# Browser tab title/icon and overall page layout.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="MusoLyze",
    page_icon="🤖",
)
# -----------------------------------------------------------------------
# Constants
# AUTH_TOKEN = os.environ.get("AUTH_TOKEN")
AUTH_TOKEN = st.secrets["AUTH_TOKEN"]  # passed to authenticator() at login
ACCESS_JSON_PATH = "access.json"
TABLE_CONFIG_PATH = "table_config.json"
CSS_PATH = "style.css"

# Inject the app stylesheet. Decode it as UTF-8 explicitly so the result does
# not depend on the platform's default encoding (e.g. cp1252 on Windows).
with open(CSS_PATH, "r", encoding="utf-8") as f:
    css_text = f.read()
st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)
# -----------------------------------------------------------------------
# Initialize Session State
# Each key is seeded only when it is missing, so values set during earlier
# reruns of the same session are preserved.
for _state_key, _state_default in (
    ("authenticated", False),
    ("history", []),
    ("dataframes", []),
    ("brand", None),
    # Brand, tables and uploaded file names of the last "Load Data" click,
    # used to detect when the selection has changed.
    ("previous_selection", {"brand": None, "tables": [], "uploaded_files": []}),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
# -----------------------------------------------------------------------
# LOGIN PAGE
# Until the user is authenticated, render only the login form and halt the
# script so none of the main app below is shown.
if not st.session_state["authenticated"]:
    st.markdown('<div class="login-container">', unsafe_allow_html=True)
    st.markdown("## MusoLyze Login")
    st.write("Please enter your email and authentication token to proceed.")
    email = st.text_input("Email", placeholder="john.doe@example.com")
    token = st.text_input("Token", type="password", placeholder="Enter your token")
    if st.button("Log In"):
        if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
            st.session_state["authenticated"] = True
            # Rerun right away so the authenticated app renders now. st.stop()
            # only halts this run, which left the login form on screen until
            # the user's next interaction.
            st.rerun()
        else:
            st.error("Invalid email or token. Please try again.")
    st.markdown('</div>', unsafe_allow_html=True)
    st.stop()  # Stop execution so the rest of the page is not shown.
# -----------------------------------------------------------------------
# Main App (only reached once authenticated): data loading and chat.
st.title("💬 MusoLyze")

# Helper used below both to read database tables and to answer questions.
sq = SmartQuery()

# Per-table settings keyed by table name; each entry carries a "source".
table_config = load_table_config(TABLE_CONFIG_PATH)
# Sidebar widgets: uploaded files, brand, and database tables.
with st.sidebar:
    st.title("Data Selection")

    # 1. File upload (falsy when nothing has been uploaded)
    uploaded_files = st.file_uploader(
        "Upload CSV or Excel files",
        type=["csv", "xlsx", "xls"],
        accept_multiple_files=True,
    )

    # 2. Brand selection, mirrored into session state
    brand = st.selectbox(
        "Choose your brand.",
        ["drumeo", "guitareo", "pianote", "singeo"],
    )
    st.session_state.brand = brand

    # 3. Table selection from the configured tables
    db_tables = st.multiselect(
        "Select tables from database",
        options=list(table_config.keys()),
        help="Select one or more tables to include in your data.",
    )

# Describe whichever tables are currently selected.
display_table_descriptions(db_tables, table_config)
# 'Load Data' button: (re)load every selected file and table into session state.
if st.sidebar.button("Load Data"):
    # 1) Snapshot the current selection so it can be compared with the last load.
    new_selection = {
        "brand": brand,
        "tables": db_tables,
        "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else [],
    }

    # 2) A different brand/table/file selection invalidates the chat history.
    if new_selection != st.session_state["previous_selection"]:
        st.session_state["history"] = []

    # 3) Load the data.
    dataframes = []
    had_errors = False

    # Uploaded files
    if uploaded_files:
        dataframes.extend(load_uploaded_files(uploaded_files))

    # Database tables. Connections are opened lazily, at most once per source
    # per click, and reused for every table from that source instead of
    # reconnecting for each table.
    connections = {}
    for table_name in db_tables:
        source = table_config[table_name]["source"]
        try:
            if source == "Snowflake":
                if source not in connections:
                    connections[source] = sq.snowflake_connection()
                df = sq.read_snowflake_table(connections[source], table_name, st.session_state.brand)
            elif source == "MySQL":
                if source not in connections:
                    connections[source] = sq.mysql_connection()
                df = sq.read_mysql_table(connections[source], table_name, st.session_state.brand)
            else:
                # Without this branch an unknown source would leave `df`
                # undefined (or stale from the previous table) and append it.
                raise ValueError(f"unsupported source {source!r}")
        except Exception as e:
            st.error(f"Error loading table {table_name}: {e}")
            had_errors = True
            continue
        dataframes.append(df)

    st.session_state['dataframes'] = dataframes

    # 4) Remember this selection for the next comparison.
    st.session_state["previous_selection"] = new_selection

    # Only claim success when every table actually loaded.
    if had_errors:
        st.warning("Data loaded with errors; see the messages above.")
    else:
        st.success("Data loaded successfully!")
# --------------------------------------------------------------------------
# Nothing below works without data: warn and halt until something is loaded.
if not st.session_state['dataframes']:
    st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
    st.stop()

# **Always** display the top 5 rows of each loaded DataFrame.
for df in st.session_state['dataframes']:
    st.markdown("**Preview of loaded data:**")
    st.dataframe(df.head(5))
# --- Chat Display Section ---
display_chat(st.session_state['history'])

# --- User Input Section ---
st.markdown("---")
# clear_on_submit resets the text box after "Send" so an answered question
# does not linger in the input. The submitted value is still returned for the
# run triggered by the submit, so the query handling below is unaffected.
with st.form(key="user_query_form", clear_on_submit=True):
    user_query = st.text_input(
        "Ask a question about your data:",
        placeholder="Type your question and press Enter...",
    )
    send_button = st.form_submit_button("Send")
# Answer a submitted question against all loaded DataFrames and record the
# exchange in the chat history; blank submissions only produce a warning.
if send_button:
    if user_query.strip():
        with st.spinner("Analyzing your data..."):
            try:
                response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])
                answer_kind = response['type']
                # "dataframe" and "plot" keep their kind (the value is stored
                # as-is); any other response kind is recorded as 'string'.
                st.session_state['history'].append({
                    'user': user_query,
                    'type': answer_kind if answer_kind in ("dataframe", "plot") else 'string',
                    'bot': response['value'],
                })
            except Exception as e:
                st.error(f"Error: {e}")
            else:
                # Rerun so the chat section above re-renders with the new entry.
                st.rerun()
    else:
        st.warning("Please enter a question before sending.")
|