Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| import pandas as pd | |
| # Local imports | |
| from auth import authenticator | |
| from utils import load_table_config, load_uploaded_files, display_table_descriptions | |
| # from SmartQuery_GC import SmartQuery | |
| from SmartQuery import SmartQuery | |
| # If you use chat_ui.py: | |
| from chat_ui import display_chat | |
# Pull variables from a local .env file (if one exists) into the process environment.
load_dotenv()
# -----------------------------------------------------------------------
# Browser-tab title/icon and page layout; kept ahead of every other st.* call.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="MusoLyze",
    page_icon="🤖",
)
# -----------------------------------------------------------------------
# Constants
# Shared login token from Streamlit secrets; indexing (not .get) makes a
# missing key fail immediately at startup.
AUTH_TOKEN = st.secrets["AUTH_TOKEN"]
ACCESS_JSON_PATH = "access.json"  # passed to authenticator() at login
TABLE_CONFIG_PATH = "table_config.json"
CSS_PATH = "style.css"

# Inject the app-wide stylesheet. Decode as UTF-8 explicitly so the result
# does not depend on the host's locale encoding (e.g. cp1252 on Windows).
with open(CSS_PATH, "r", encoding="utf-8") as f:
    css_text = f.read()
st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)
# -----------------------------------------------------------------------
# Initialize Session State
# Each key is seeded only when it is absent, i.e. on the first run of a
# browser session; later reruns keep whatever the app has stored since.
for _key, _default in (
    ("authenticated", False),
    ("history", []),
    ("dataframes", []),
    ("brand", None),
    # Brand / tables / uploaded file names from the last "Load Data" click,
    # compared against the current selection to decide whether to reset history.
    ("previous_selection", {"brand": None, "tables": [], "uploaded_files": []}),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# -----------------------------------------------------------------------
# LOGIN PAGE
# Until the user authenticates, render only the login form and halt the script.
if not st.session_state["authenticated"]:
    # NOTE(review): each st.markdown call renders as its own element, so this
    # opening <div> likely does not wrap the widgets below — confirm the
    # .login-container rules in style.css actually take effect.
    st.markdown('<div class="login-container">', unsafe_allow_html=True)
    st.markdown("## MusoLyze Login")
    st.write("Please enter your email and authentication token to proceed.")
    email = st.text_input("Email", placeholder="john.doe@example.com")
    token = st.text_input("Token", type="password", placeholder="Enter your token")
    if st.button("Log In"):
        if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
            st.session_state["authenticated"] = True
            # Rerun right away so the main app renders in this interaction.
            # Stopping here instead would leave the login form on screen until
            # the user clicked something else, and any success message would
            # be wiped by the rerun anyway.
            st.rerun()
        else:
            st.error("Invalid email or token. Please try again.")
    st.markdown('</div>', unsafe_allow_html=True)
    st.stop()  # Stop execution so the rest of the page is not shown.
# -----------------------------------------------------------------------
# Main App (reached only after login): data selection, then chat
st.title("💬 MusoLyze")

# Query engine: opens the database connections and answers chat questions.
sq = SmartQuery()

# Table name -> table metadata (each entry carries a "source" field).
table_config = load_table_config(TABLE_CONFIG_PATH)

sidebar = st.sidebar
sidebar.title("Data Selection")

# 1. User-supplied spreadsheets (several files may be uploaded at once).
uploaded_files = sidebar.file_uploader(
    "Upload CSV or Excel files",
    accept_multiple_files=True,
    type=['csv', 'xlsx', 'xls'],
)

# 2. Brand, mirrored into session state where the table readers pick it up.
brand = sidebar.selectbox("Choose your brand.", ["drumeo", "guitareo", "pianote", "singeo"])
st.session_state["brand"] = brand

# 3. Database tables offered by the table config.
db_tables = sidebar.multiselect(
    "Select tables from database",
    help="Select one or more tables to include in your data.",
    options=list(table_config),
)

# Describe whichever tables are currently selected.
display_table_descriptions(db_tables, table_config)
# 'Load Data' button: rebuild the in-memory DataFrames from the current
# sidebar selection (uploaded files + selected database tables).
if st.sidebar.button("Load Data"):
    # 1) Snapshot the selection so it can be compared with the last load.
    new_selection = {
        "brand": brand,
        "tables": db_tables,
        "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else []
    }
    # 2) A changed selection means earlier answers refer to different data,
    #    so the chat history is cleared.
    if new_selection != st.session_state["previous_selection"]:
        st.session_state["history"] = []
    # 3) Load uploaded files first, then each selected database table.
    dataframes = []
    if uploaded_files:
        dataframes.extend(load_uploaded_files(uploaded_files))
    failed_tables = []
    for table_name in db_tables:
        table_info = table_config[table_name]
        source = table_info["source"]
        try:
            if source == 'Snowflake':
                session = sq.snowflake_connection()
                df = sq.read_snowflake_table(session, table_name, st.session_state.brand)
            elif source == 'MySQL':
                engine = sq.mysql_connection()
                df = sq.read_mysql_table(engine, table_name, st.session_state.brand)
            else:
                # Without this branch `df` would be unbound on the first table
                # (NameError) or still hold the previous table's frame, which
                # would then be appended a second time.
                raise ValueError(f"unsupported source {source!r}")
        except Exception as e:
            st.error(f"Error loading table {table_name}: {e}")
            failed_tables.append(table_name)
        else:
            dataframes.append(df)
    st.session_state['dataframes'] = dataframes
    # 4) Remember this selection for the next comparison.
    st.session_state["previous_selection"] = new_selection
    # Claim success only when every requested source loaded. An empty load
    # falls through to the page's "no data" warning.
    if failed_tables:
        st.warning(
            f"Loaded {len(dataframes)} dataset(s); could not load: "
            f"{', '.join(map(str, failed_tables))}."
        )
    elif dataframes:
        st.success("Data loaded successfully!")
# --------------------------------------------------------------------------
# If no data is loaded, warn and stop
if not st.session_state['dataframes']:
    st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
    st.stop()
# **Always** display the top 5 rows of each loaded DataFrame. Previews are
# numbered (1-based, in load order) so they can be told apart.
for idx, df in enumerate(st.session_state['dataframes'], start=1):
    st.markdown(f"**Preview of loaded data ({idx}):**")
    st.dataframe(df.head(5))
# --- Chat Display Section ---
display_chat(st.session_state['history'])
# --- User Input Section ---
st.markdown("---")
# clear_on_submit empties the text box after each submission. A plain
# st.rerun() does not, because widget values persist across reruns.
# The submitted text is still returned in the run that handles the submit.
with st.form(key="user_query_form", clear_on_submit=True):
    user_query = st.text_input(
        "Ask a question about your data:",
        placeholder="Type your question and press Enter..."
    )
    send_button = st.form_submit_button("Send")
# Handle a submitted question. Blank input only gets a warning. Otherwise the
# question runs against every loaded DataFrame and the answer is appended to
# the chat history.
if send_button:
    if not user_query.strip():
        st.warning("Please enter a question before sending.")
    else:
        with st.spinner("Analyzing your data..."):
            try:
                response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])
                # Tables and charts keep their own tag; any other answer kind
                # is recorded as plain text.
                entry_type = next(
                    (tag for tag in ("dataframe", "plot") if response['type'] == tag),
                    "string",
                )
                st.session_state['history'].append({
                    'user': user_query,
                    'type': entry_type,
                    'bot': response['value'],  # DataFrame, plot image, or text
                })
                # Rerun so the chat section re-renders with the new entry.
                st.rerun()
            except Exception as e:
                st.error(f"Error: {e}")