"""MusoLyze Streamlit app.

Flow of a run:
1. Gate the page behind an email/token login (``auth.authenticator``).
2. Let the user pick uploaded CSV/Excel files and/or database tables (per brand)
   in the sidebar, and load them into ``st.session_state['dataframes']``.
3. Preview the loaded frames and answer free-text questions about them via
   ``SmartQuery.perform_query_on_dataframes``, keeping a chat history.
"""
import os
import streamlit as st
from dotenv import load_dotenv
import pandas as pd

# Local imports
from auth import authenticator
from utils import load_table_config, load_uploaded_files, display_table_descriptions
# from SmartQuery_GC import SmartQuery
from SmartQuery import SmartQuery
# If you use chat_ui.py:
from chat_ui import display_chat

load_dotenv()

# -----------------------------------------------------------------------
# Set page config
st.set_page_config(
    page_title="MusoLyze",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# -----------------------------------------------------------------------
# Constants
# AUTH_TOKEN = os.environ.get("AUTH_TOKEN")
AUTH_TOKEN = st.secrets["AUTH_TOKEN"]
ACCESS_JSON_PATH = "access.json"
TABLE_CONFIG_PATH = "table_config.json"
CSS_PATH = "style.css"

# Inject the custom stylesheet. The CSS has to be wrapped in a <style> tag;
# an empty f"" string reads style.css but never applies it.
with open(CSS_PATH, "r", encoding="utf-8") as f:
    css_text = f.read()
st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)

# -----------------------------------------------------------------------
# Initialize Session State
# Only missing keys are seeded, so values persist across Streamlit reruns.
_SESSION_DEFAULTS = {
    "authenticated": False,
    "history": [],
    "dataframes": [],
    "brand": None,
    # Brand, tables and uploaded file names used for the last 'Load Data';
    # compared on the next load to decide whether to reset the chat history.
    "previous_selection": {
        "brand": None,
        "tables": [],
        "uploaded_files": [],
    },
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# -----------------------------------------------------------------------
# LOGIN PAGE
if not st.session_state["authenticated"]:
    # NOTE(review): the HTML passed to this and the closing st.markdown call
    # below appears to have been lost (only a newline remains). Streamlit renders
    # each st.markdown call as its own element, so an opening <div> here and a
    # closing </div> later cannot wrap the widgets in between anyway.
    st.markdown("\n", unsafe_allow_html=True)
    st.markdown("## MusoLyze Login")
    st.write("Please enter your email and authentication token to proceed.")

    email = st.text_input("Email", placeholder="john.doe@example.com")
    token = st.text_input("Token", type="password", placeholder="Enter your token")

    if st.button("Log In"):
        if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
            st.session_state["authenticated"] = True
            st.success("Logged in successfully!")
            # st.stop() would only end this run and leave the login form on
            # screen until the next interaction; rerun to show the app now.
            st.rerun()
        else:
            st.error("Invalid email or token. Please try again.")

    st.markdown("\n", unsafe_allow_html=True)
    st.stop()  # Stop execution so the rest of the page is not shown.

# -----------------------------------------------------------------------
# Main App: Load Data, Show Chat
st.title("💬 MusoLyze")

# SmartQuery instance
sq = SmartQuery()

# Load config file for database tables
table_config = load_table_config(TABLE_CONFIG_PATH)

# Sidebar for file upload and table selection
st.sidebar.title("Data Selection")

# 1. File upload
uploaded_files = st.sidebar.file_uploader(
    "Upload CSV or Excel files",
    type=['csv', 'xlsx', 'xls'],
    accept_multiple_files=True,
)

# 2. Brand selection
brand = st.sidebar.selectbox("Choose your brand.", ["drumeo", "guitareo", "pianote", "singeo"])
st.session_state.brand = brand

# 3. Table selection
db_tables = st.sidebar.multiselect(
    "Select tables from database",
    options=list(table_config.keys()),
    help="Select one or more tables to include in your data.",
)

# Show table descriptions if user has selected any
display_table_descriptions(db_tables, table_config)

# 'Load Data' button
if st.sidebar.button("Load Data"):
    # 1) Snapshot of the current selection, compared with previous_selection.
    new_selection = {
        "brand": brand,
        "tables": db_tables,
        "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else [],
    }

    # 2) A different selection means old chat answers refer to other data.
    if new_selection != st.session_state["previous_selection"]:
        st.session_state["history"] = []

    # 3) Load the data.
    dataframes = []

    if uploaded_files:
        dataframes.extend(load_uploaded_files(uploaded_files))

    # NOTE(review): a new connection is opened per table and never closed
    # here; confirm whether SmartQuery manages session/engine lifetime.
    for table_name in db_tables:
        source = table_config[table_name]["source"]
        try:
            if source == 'Snowflake':
                session = sq.snowflake_connection()
                df = sq.read_snowflake_table(session, table_name, st.session_state.brand)
            elif source == 'MySQL':
                engine = sq.mysql_connection()
                df = sq.read_mysql_table(engine, table_name, st.session_state.brand)
            else:
                # Without this, `df` would be unbound or still hold the
                # previous table's frame and be appended a second time.
                raise ValueError(f"unsupported source {source!r}")
            dataframes.append(df)
        except Exception as e:
            st.error(f"Error loading table {table_name}: {e}")

    st.session_state['dataframes'] = dataframes

    # 4) Remember this selection for the next comparison.
    st.session_state["previous_selection"] = new_selection

    # With nothing loaded, the warning below explains what to do instead.
    if dataframes:
        st.success("Data loaded successfully!")

# --------------------------------------------------------------------------
# If no data is loaded, warn and stop
if not st.session_state['dataframes']:
    st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
    st.stop()

# **Always** display top 5 rows of each DataFrame if data is loaded
for preview_df in st.session_state['dataframes']:
    st.markdown("**Preview of loaded data:**")
    st.dataframe(preview_df.head(5))

# --- Chat Display Section ---
display_chat(st.session_state['history'])

# --- User Input Section ---
st.markdown("---")
with st.form(key="user_query_form"):
    user_query = st.text_input(
        "Ask a question about your data:",
        placeholder="Type your question and press Enter...",
    )
    send_button = st.form_submit_button("Send")

if send_button and user_query.strip():
    with st.spinner("Analyzing your data..."):
        try:
            response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])
            # "dataframe" and "plot" results are stored under their own type;
            # any other response type is stored as text ("string").
            response_type = response['type']
            entry_type = response_type if response_type in ("dataframe", "plot") else "string"
            history_entry = {
                'user': user_query,
                'type': entry_type,
                'bot': response['value'],
            }
        except Exception as e:
            st.error(f"Error: {e}")
        else:
            st.session_state['history'].append(history_entry)
            # Rerun to refresh the page and clear the input. Kept outside the
            # `try` so Streamlit's rerun control exception is never caught there.
            st.rerun()
elif send_button:
    st.warning("Please enter a question before sending.")