Danialebrat commited on
Commit
fe8a467
·
1 Parent(s): 806f16d

adding codes and files

Browse files
Files changed (13) hide show
  1. .dockerignore +5 -0
  2. .gitignore +5 -0
  3. README.md +2 -14
  4. SmartQuery.py +190 -0
  5. SmartQuery_GC.py +190 -0
  6. access.json +3 -0
  7. app.py +210 -2
  8. auth.py +14 -0
  9. chat_ui.py +55 -0
  10. local_app.py +211 -0
  11. style.css +47 -0
  12. table_config.json +23 -0
  13. utils.py +35 -0
.dockerignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Ignore the .streamlit directory and its contents
2
+ .streamlit/
3
+
4
+ # Ignore the .env file
5
+ .env
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Ignore the .streamlit directory and its contents
2
+ .streamlit/
3
+
4
+ # Ignore the .env file
5
+ .env
README.md CHANGED
@@ -1,14 +1,2 @@
1
- ---
2
- title: Musolyze
3
- emoji: 🚀
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: streamlit
7
- sdk_version: 1.43.2
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Analyzing Musora databases using natural language
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # SmartQuery
2
+ Ask questions from your data in natural language
 
 
 
 
 
 
 
 
 
 
 
 
SmartQuery.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pandasai.llm import OpenAI
2
+ from pandasai import Agent
3
+ from pandasai import SmartDataframe, SmartDatalake
4
+ from pandasai.responses.response_parser import ResponseParser
5
+ from pandasai.responses.streamlit_response import StreamlitResponse
6
+ from snowflake.snowpark import Session
7
+ import json
8
+ import pandas as pd
9
+ from sqlalchemy import create_engine
10
+ import os
11
+ from dotenv import load_dotenv
12
+ import streamlit as st
13
+ load_dotenv()
14
+
15
+ # -----------------------------------------------------------------------
16
+ key = st.secrets["pandasai"]["PANDASAI_API_KEY"]
17
+ os.environ['PANDASAI_API_KEY'] = key
18
+ openai_llm = OpenAI(
19
+ api_token=st.secrets["openai"]["OPENAI_API"]
20
+ )
21
+
22
+ # -----------------------------------------------------------------------
23
+ # -----------------------------------------------------------------------
24
+ class SmartQuery:
25
+ """
26
+ class for interacting with dataframes using Natural Language
27
+ """
28
+
29
+ def __init__(self):
30
+ with open("table_config.json", "r") as f:
31
+ self.config = json.load(f)
32
+
33
+ def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
34
+ """
35
+ Performs a user-defined query on given pandas DataFrames using PandasAI.
36
+
37
+ Parameters:
38
+ - query (str): The user's query or instruction.
39
+ - *dataframes (pd.DataFrame): Any number of pandas DataFrames.
40
+
41
+ Returns:
42
+ - The result of the query executed by PandasAI.
43
+ """
44
+
45
+ dataframe_list = list(dataframes)
46
+ num_dataframes = len(dataframe_list)
47
+
48
+ config = {"llm": openai_llm, "verbose": True, "security": "none", "response_parser": OutputParser}
49
+
50
+ if num_dataframes == 1:
51
+ result = self.query_single_dataframe(query, dataframe_list[0], config)
52
+
53
+ else:
54
+ result = self.query_multiple_dataframes(query, dataframe_list, config)
55
+
56
+ return result
57
+
58
+ def query_single_dataframe(self, query, dataframe, config):
59
+
60
+ agent = Agent(dataframe, config=config)
61
+ response = agent.chat(query)
62
+
63
+ return response
64
+
65
+ def query_multiple_dataframes(self, query, dataframe_list, config):
66
+
67
+ agent = SmartDatalake(dataframe_list, config=config)
68
+ response = agent.chat(query)
69
+
70
+ return response
71
+
72
+ # -----------------------------------------------------------------------
73
+ def snowflake_connection(self):
74
+ """
75
+ setting snowflake connection
76
+ :return:
77
+ """
78
+
79
+ conn = {
80
+ "user": os.environ.get("snowflake_user"),
81
+ "password": os.environ.get("snowflake_password"),
82
+ "account": os.environ.get("snowflake_account"),
83
+ "role": os.environ.get("snowflake_role"),
84
+ "database": os.environ.get("snowflake_database"),
85
+ "warehouse": os.environ.get("snowflake_warehouse"),
86
+ "schema": os.environ.get("snowflake_schema")
87
+ }
88
+ try:
89
+ session = Session.builder.configs(conn).create()
90
+ return session
91
+ except Exception as e:
92
+ print(f"Error creating Snowflake session: {e}")
93
+ raise e
94
+
95
+ # ----------------------------------------------------------------------------------------------------
96
+ def read_snowflake_table(self, session, table_name, brand):
97
+ """
98
+ reading tables from snowflake
99
+ :param dataframe:
100
+ :return:
101
+ """
102
+ query = self._get_query(table_name, brand)
103
+
104
+ # Connect to Snowflake
105
+ try:
106
+ dataframe = session.sql(query).to_pandas()
107
+ dataframe.columns = dataframe.columns.str.lower()
108
+ print(f"reading content table successfully")
109
+ return dataframe
110
+ except Exception as e:
111
+ print(f"Error in reading table: {e}")
112
+
113
+ # ----------------------------------------------------------------------------------------------------
114
+ def _get_query(self, table_name: str, brand: str) -> str:
115
+ # Retrieve the base query template for the given table name
116
+ base_query = self.config[table_name]["query"]
117
+
118
+ # Insert the brand condition into the query
119
+ query = base_query.format(brand=brand.lower())
120
+
121
+ return query
122
+
123
+ # ----------------------------------------------------------------------------------------------------
124
+ def mysql_connection(self):
125
+
126
+ # Setting up the MySQL connection parameters
127
+ user = os.environ.get("mysql_user")
128
+ password = os.environ.get("mysql_password")
129
+ host = os.environ.get("mysql_source")
130
+ database = os.environ.get("mysql_schema")
131
+
132
+ try:
133
+ engine = create_engine(f"mysql+pymysql://{user}:{password}@{host}/{database}")
134
+ return engine
135
+ except Exception as e:
136
+ print(f"Error creating MySQL engine: {e}")
137
+ raise e
138
+
139
+ # ----------------------------------------------------------------------------------------------------
140
+ def read_mysql_table(self, engine, table_name, brand):
141
+
142
+ query = self._get_query(table_name, brand)
143
+
144
+ with engine.connect() as conn:
145
+ dataframe = pd.read_sql_query(query, conn)
146
+
147
+ # Convert all column names to lowercase if not
148
+ dataframe.columns = dataframe.columns.str.lower()
149
+
150
+ return dataframe
151
+
152
+
153
+ # ----------------------------------------------------------------------------------------------------
154
+ # ----------------------------------------------------------------------------------------------------
155
+ class OutputParser(ResponseParser):
156
+ def __init__(self, context) -> None:
157
+ super().__init__(context)
158
+
159
+ def parse(self, result):
160
+ return result
161
+
162
+ # ----------------------------------------------------------------------------------------------------
163
+
164
+
165
+ if __name__ == "__main__":
166
+
167
+ query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"
168
+ query = "select the comments that was on 'pack-bundle-lesson' content_type and have more than 10 likes"
169
+ query2 = "what is the number of likes, content_title and content_description for the content that received the most comments? "
170
+
171
+ dataframe_path = "data/recent_comment_test.csv"
172
+
173
+ dataframe1 = pd.read_csv(dataframe_path)
174
+
175
+ sq = SmartQuery()
176
+ interactions_path = "DBT_ANALYTICS.CORE.FCT_CONTENT_INTERACTIONS"
177
+ content_path = "DBT_ANALYTICS.CORE.DIM_CONTENT"
178
+ session = sq.snowflake_connection()
179
+
180
+
181
+ interactions_df = sq.read_snowflake_table(session, table_name="interactions", brand="drumeo")
182
+ content_df = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")
183
+
184
+ # single dataframe
185
+ # result = sq.perform_query_on_dataframes(query, dataframe, response_format="dataframe")
186
+
187
+ # multiple dataframe
188
+ result = sq.perform_query_on_dataframes(query_multi, interactions_df, content_df, response_format="dataframe")
189
+
190
+ print(result)
SmartQuery_GC.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this class will use env variables to read secrets from google cloud
2
+ from pandasai import Agent
3
+ from pandasai import SmartDataframe, SmartDatalake
4
+ from pandasai.responses.response_parser import ResponseParser
5
+ from pandasai.llm.openai import OpenAI
6
+ from pandasai.responses.streamlit_response import StreamlitResponse
7
+ import pymysql
8
+ from pandasai.connectors import PandasConnector
9
+ from snowflake.snowpark import Session
10
+ import json
11
+ import pandas as pd
12
+ from sqlalchemy import create_engine
13
+ import os
14
+ import streamlit as st
15
+ from dotenv import load_dotenv
16
+
17
+ load_dotenv()
18
+ import datetime
19
+
20
+ # -----------------------------------------------------------------------
21
+ key = os.environ.get("PANDASAI_API_KEY")
22
+ os.environ['PANDASAI_API_KEY'] = key
23
+
24
+ # openai_llm = OpenAI(api_key=os.environ.get("OPENAI_API"))
25
+
26
+ openai_llm = OpenAI(
27
+ api_token=os.environ.get("OPENAI_API")
28
+ )
29
+
30
+
31
+ # -----------------------------------------------------------------------
32
+ class SmartQuery:
33
+ """
34
+ class for interacting with dataframes using Natural Language
35
+ """
36
+
37
+ def __init__(self):
38
+ with open("table_config.json", "r") as f:
39
+ self.config = json.load(f)
40
+
41
+ def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
42
+ """
43
+ Performs a user-defined query on given pandas DataFrames using PandasAI.
44
+
45
+ Parameters:
46
+ - query (str): The user's query or instruction.
47
+ - *dataframes (pd.DataFrame): Any number of pandas DataFrames.
48
+
49
+ Returns:
50
+ - The result of the query executed by PandasAI.
51
+ """
52
+
53
+ dataframe_list = list(dataframes)
54
+ num_dataframes = len(dataframe_list)
55
+
56
+ config = {"llm": openai_llm, "verbose": True, "security": "none", "response_parser": OutputParser}
57
+
58
+ if num_dataframes == 1:
59
+ result = self.query_single_dataframe(query, dataframe_list[0], config)
60
+
61
+ else:
62
+ result = self.query_multiple_dataframes(query, dataframe_list, config)
63
+
64
+ return result
65
+
66
+ def query_single_dataframe(self, query, dataframe, config):
67
+
68
+ agent = Agent(dataframe, config=config)
69
+ response = agent.chat(query)
70
+
71
+ return response
72
+
73
+ def query_multiple_dataframes(self, query, dataframe_list, config):
74
+
75
+ agent = SmartDatalake(dataframe_list, config=config)
76
+ response = agent.chat(query)
77
+
78
+ return response
79
+
80
+ # -----------------------------------------------------------------------
81
+ def snowflake_connection(self):
82
+ """
83
+ setting snowflake connection
84
+ :return:
85
+ """
86
+
87
+ conn = {
88
+ "user": os.environ.get("snowflake_user"),
89
+ "password": os.environ.get("snowflake_password"),
90
+ "account": os.environ.get("snowflake_account"),
91
+ "role": os.environ.get("snowflake_role"),
92
+ "database": os.environ.get("snowflake_database"),
93
+ "warehouse": os.environ.get("snowflake_warehouse"),
94
+ "schema": os.environ.get("snowflake_schema")
95
+ }
96
+ try:
97
+ session = Session.builder.configs(conn).create()
98
+ return session
99
+ except Exception as e:
100
+ print(f"Error creating Snowflake session: {e}")
101
+ raise e
102
+
103
+ # ----------------------------------------------------------------------------------------------------
104
+ def read_snowflake_table(self, session, table_name, brand):
105
+ """
106
+ reading tables from snowflake
107
+ :param dataframe:
108
+ :return:
109
+ """
110
+ query = self._get_query(table_name, brand)
111
+
112
+ # Connect to Snowflake
113
+ try:
114
+ dataframe = session.sql(query).to_pandas()
115
+ dataframe.columns = dataframe.columns.str.lower()
116
+ print(f"reading content table successfully")
117
+ return dataframe
118
+ except Exception as e:
119
+ print(f"Error in reading table: {e}")
120
+
121
+ # ----------------------------------------------------------------------------------------------------
122
+ def _get_query(self, table_name: str, brand: str) -> str:
123
+ # Retrieve the base query template for the given table name
124
+ base_query = self.config[table_name]["query"]
125
+
126
+ # Insert the brand condition into the query
127
+ query = base_query.format(brand=brand.lower())
128
+
129
+ return query
130
+
131
+ # ----------------------------------------------------------------------------------------------------
132
+ def mysql_connection(self):
133
+
134
+ # Setting up the MySQL connection parameters
135
+ user = os.environ.get("mysql_user")
136
+ password = os.environ.get("mysql_password")
137
+ host = os.environ.get("mysql_source")
138
+ database = os.environ.get("mysql_schema")
139
+
140
+ try:
141
+ engine = create_engine(f"mysql+pymysql://{user}:{password}@{host}/{database}")
142
+ return engine
143
+ except Exception as e:
144
+ print(f"Error creating MySQL engine: {e}")
145
+ raise e
146
+
147
+ # ----------------------------------------------------------------------------------------------------
148
+ def read_mysql_table(self, engine, table_name, brand):
149
+
150
+ query = self._get_query(table_name, brand)
151
+
152
+ with engine.connect() as conn:
153
+ dataframe = pd.read_sql_query(query, conn)
154
+
155
+ # Convert all column names to lowercase if not
156
+ dataframe.columns = dataframe.columns.str.lower()
157
+
158
+ return dataframe
159
+
160
+
161
+ # ----------------------------------------------------------------------------------------------------
162
+ # ----------------------------------------------------------------------------------------------------
163
+ class OutputParser(ResponseParser):
164
+ def __init__(self, context) -> None:
165
+ super().__init__(context)
166
+
167
+ def parse(self, result):
168
+ return result
169
+
170
+ # ----------------------------------------------------------------------------------------------------
171
+
172
+
173
+ if __name__ == "__main__":
174
+ # query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"
175
+ # query = "select the comments that was on 'pack-bundle-lesson' content_type and have more than 10 likes"
176
+ # query2 = "what is the number of likes, content_title and content_description for the content that received the most comments? "
177
+ # query = "how many users do we have with 0 experience level?"
178
+ query = "select song content_type that have difficulty range of 0-3?"
179
+ #
180
+ # dataframe_path = "data/recent_comment_test.csv"
181
+ #
182
+ # dataframe1 = pd.read_csv(dataframe_path)
183
+ #
184
+ sq = SmartQuery()
185
+ session = sq.snowflake_connection()
186
+ dataframe = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")
187
+
188
+ result = sq.perform_query_on_dataframes(query, dataframe)
189
+
190
+ print(result)
access.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "email": ["danial@musora.com", "danial.ebrat@gmail.com"]
3
+ }
app.py CHANGED
@@ -1,4 +1,212 @@
 
1
  import streamlit as st
 
 
2
 
3
- x = st.slider('Select a value PLEASE')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import streamlit as st
3
+ from dotenv import load_dotenv
4
+ import pandas as pd
5
 
6
+ # Local imports
7
+ from auth import authenticator
8
+ from utils import load_table_config, load_uploaded_files, display_table_descriptions
9
+ # from SmartQuery_GC import SmartQuery
10
+ from SmartQuery import SmartQuery
11
+ # If you use chat_ui.py:
12
+ from chat_ui import display_chat
13
+
14
+ load_dotenv()
15
+
16
+ # -----------------------------------------------------------------------
17
+ # Set page config
18
+ st.set_page_config(
19
+ page_title="MusoLyze",
20
+ page_icon="🤖",
21
+ layout="wide",
22
+ initial_sidebar_state="expanded",
23
+ )
24
+
25
+ # -----------------------------------------------------------------------
26
+ # Constants
27
+ # AUTH_TOKEN = os.environ.get("AUTH_TOKEN")
28
+ AUTH_TOKEN = st.secrets["token"]["AUTH_TOKEN"]
29
+ ACCESS_JSON_PATH = "access.json"
30
+ TABLE_CONFIG_PATH = "table_config.json"
31
+ CSS_PATH = "style.css"
32
+
33
+ with open(CSS_PATH, "r") as f:
34
+ css_text = f.read()
35
+ st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)
36
+
37
+ # -----------------------------------------------------------------------
38
+ # Initialize Session State
39
+ if "authenticated" not in st.session_state:
40
+ st.session_state["authenticated"] = False
41
+ if "history" not in st.session_state:
42
+ st.session_state["history"] = []
43
+ if "dataframes" not in st.session_state:
44
+ st.session_state["dataframes"] = []
45
+ if "brand" not in st.session_state:
46
+ st.session_state["brand"] = None
47
+
48
+ # NEW: Track the previous selection of brand, tables, and uploaded file names.
49
+ if "previous_selection" not in st.session_state:
50
+ st.session_state["previous_selection"] = {
51
+ "brand": None,
52
+ "tables": [],
53
+ "uploaded_files": []
54
+ }
55
+
56
+ # -----------------------------------------------------------------------
57
+ # LOGIN PAGE
58
+ if not st.session_state["authenticated"]:
59
+ st.markdown('<div class="login-container">', unsafe_allow_html=True)
60
+ st.markdown("## MusoLyze Login")
61
+ st.write("Please enter your email and authentication token to proceed.")
62
+
63
+ email = st.text_input("Email", placeholder="john.doe@example.com")
64
+ token = st.text_input("Token", type="password", placeholder="Enter your token")
65
+
66
+ if st.button("Log In"):
67
+ if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
68
+ st.session_state["authenticated"] = True
69
+ st.success("Logged in successfully!")
70
+ st.stop() # Force the script to end; next run user is authenticated.
71
+ else:
72
+ st.error("Invalid email or token. Please try again.")
73
+
74
+ st.markdown('</div>', unsafe_allow_html=True)
75
+ st.stop() # Stop execution so the rest of the page is not shown.
76
+
77
+ # -----------------------------------------------------------------------
78
+ # Main App: Load Data, Show Chat
79
+ st.title("💬 MusoLyze")
80
+
81
+ # SmartQuery instance
82
+ sq = SmartQuery()
83
+
84
+ # Load config file for database tables
85
+ table_config = load_table_config(TABLE_CONFIG_PATH)
86
+
87
+ # Sidebar for file upload and table selection
88
+ st.sidebar.title("Data Selection")
89
+
90
+ # 1. File upload
91
+ uploaded_files = st.sidebar.file_uploader(
92
+ "Upload CSV or Excel files",
93
+ type=['csv', 'xlsx', 'xls'],
94
+ accept_multiple_files=True
95
+ )
96
+
97
+ # 2. Brand selection
98
+ brand = st.sidebar.selectbox("Choose your brand.", ["drumeo", "guitareo", "pianote", "singeo"])
99
+ st.session_state.brand = brand
100
+
101
+ # 3. Table selection
102
+ db_tables = st.sidebar.multiselect(
103
+ "Select tables from database",
104
+ options=list(table_config.keys()),
105
+ help="Select one or more tables to include in your data."
106
+ )
107
+
108
+ # Show table descriptions if user has selected any
109
+ display_table_descriptions(db_tables, table_config)
110
+
111
+ # 'Load Data' button
112
+ if st.sidebar.button("Load Data"):
113
+ # 1) Build the new selection object to compare with previous_selection.
114
+ new_selection = {
115
+ "brand": brand,
116
+ "tables": db_tables,
117
+ "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else []
118
+ }
119
+
120
+ # 2) Compare new selection with old selection; if changed, reset history.
121
+ if new_selection != st.session_state["previous_selection"]:
122
+ st.session_state["history"] = []
123
+
124
+ # 3) Proceed with loading data
125
+ dataframes = []
126
+
127
+ # Load from uploaded files
128
+ if uploaded_files:
129
+ dataframes.extend(load_uploaded_files(uploaded_files))
130
+
131
+ # Load dataframes from selected tables
132
+ if db_tables:
133
+ for table_name in db_tables:
134
+ table_info = table_config[table_name]
135
+ source = table_info["source"]
136
+ try:
137
+ if source == 'Snowflake':
138
+ session = sq.snowflake_connection()
139
+ df = sq.read_snowflake_table(session, table_name, st.session_state.brand)
140
+ elif source == 'MySQL':
141
+ engine = sq.mysql_connection()
142
+ df = sq.read_mysql_table(engine, table_name, st.session_state.brand)
143
+ dataframes.append(df)
144
+ except Exception as e:
145
+ st.error(f"Error loading table {table_name}: {e}")
146
+
147
+ st.session_state['dataframes'] = dataframes
148
+
149
+ # 4) Update previous_selection in session state
150
+ st.session_state["previous_selection"] = new_selection
151
+
152
+ st.success("Data loaded successfully!")
153
+
154
+ # --------------------------------------------------------------------------
155
+ # If no data is loaded, warn and stop
156
+ if not st.session_state['dataframes']:
157
+ st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
158
+ st.stop()
159
+
160
+ # **Always** display top 5 rows of each DataFrame if data is loaded
161
+ for idx, df in enumerate(st.session_state['dataframes']):
162
+ st.markdown(f"**Preview of loaded data:**")
163
+ st.dataframe(df.head(5))
164
+
165
+ # --- Chat Display Section ---
166
+ display_chat(st.session_state['history'])
167
+
168
+ # --- User Input Section ---
169
+ st.markdown("---")
170
+
171
+ with st.form(key="user_query_form"):
172
+ user_query = st.text_input(
173
+ "Ask a question about your data:",
174
+ placeholder="Type your question and press Enter..."
175
+ )
176
+ send_button = st.form_submit_button("Send")
177
+
178
+ if send_button and user_query.strip():
179
+ with st.spinner("Analyzing your data..."):
180
+ try:
181
+ response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])
182
+
183
+ if response['type'] == "dataframe":
184
+ df = response['value']
185
+ st.session_state['history'].append({
186
+ 'user': user_query,
187
+ 'type': 'dataframe',
188
+ 'bot': df # store the actual DataFrame
189
+ })
190
+ elif response['type'] == "plot":
191
+ plot_image = response['value']
192
+ st.session_state['history'].append({
193
+ 'user': user_query,
194
+ 'type': 'plot',
195
+ 'bot': plot_image
196
+ })
197
+ else: # string or any other text
198
+ text_response = response['value']
199
+ st.session_state['history'].append({
200
+ 'user': user_query,
201
+ 'type': 'string',
202
+ 'bot': text_response
203
+ })
204
+
205
+ # Rerun to refresh page and clear input
206
+ st.rerun()
207
+
208
+ except Exception as e:
209
+ st.error(f"Error: {e}")
210
+
211
+ elif send_button and not user_query.strip():
212
+ st.warning("Please enter a question before sending.")
auth.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ def load_access_json(file_path: str) -> dict:
5
+ """Load the JSON file containing the allowed emails."""
6
+ with open(file_path, 'r') as f:
7
+ return json.load(f)
8
+
9
+ def authenticator(email: str, token: str, auth_token: str, access_json_path: str) -> bool:
10
+ """Check if the provided email and token are valid."""
11
+ emails_data = load_access_json(access_json_path)
12
+ email_list = emails_data["email"]
13
+
14
+ return (email.lower() in email_list) and (token == auth_token)
chat_ui.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ def display_chat(history):
5
+ """Renders the chat history with custom bubbles for each message."""
6
+ chat_container = st.container()
7
+ with chat_container:
8
+ for idx, chat in enumerate(history):
9
+ # --- User message ---
10
+ st.markdown(
11
+ f"""
12
+ <div class="chat-bubble user-bubble">
13
+ <strong>You:</strong> {chat['user']}
14
+ </div>
15
+ """,
16
+ unsafe_allow_html=True
17
+ )
18
+
19
+ # --- Bot bubble: use the 'type' key to decide how to render ---
20
+ st.markdown(
21
+ """
22
+ <div class="chat-bubble bot-bubble">
23
+ <strong>Bot:</strong>
24
+ """,
25
+ unsafe_allow_html=True,
26
+ )
27
+
28
+ response_type = chat.get('type', 'string') # default to 'string'
29
+ bot_response = chat['bot']
30
+
31
+ if response_type == 'dataframe' and isinstance(bot_response, pd.DataFrame):
32
+ # Show top 5 rows
33
+ df_to_display = bot_response
34
+ if len(df_to_display) > 5:
35
+ st.info("Showing the first 5 rows of the DataFrame.")
36
+ st.dataframe(df_to_display.head(5))
37
+
38
+ # Provide a CSV download
39
+ csv_data = df_to_display.to_csv(index=False).encode('utf-8')
40
+ st.download_button(
41
+ label="Download data as CSV",
42
+ data=csv_data,
43
+ file_name=f'result_{idx+1}.csv',
44
+ mime='text/csv',
45
+ key=f'download_{idx}'
46
+ )
47
+
48
+ elif response_type == 'plot':
49
+ # If it's an image object (e.g., PIL Image), show it
50
+ st.image(bot_response, use_container_width=True)
51
+
52
+ else: # "string" or any other text
53
+ st.markdown(f"{bot_response}", unsafe_allow_html=True)
54
+
55
+ st.markdown("</div>", unsafe_allow_html=True)
local_app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from dotenv import load_dotenv
4
+ import pandas as pd
5
+
6
+ # Local imports
7
+ from auth import authenticator
8
+ from utils import load_table_config, load_uploaded_files, display_table_descriptions
9
+ # from SmartQuery_GC import SmartQuery
10
+ from SmartQuery import SmartQuery
11
+ # If you use chat_ui.py:
12
+ from chat_ui import display_chat
13
+
14
+ load_dotenv()
15
+
16
+ # -----------------------------------------------------------------------
17
+ # Set page config
18
+ st.set_page_config(
19
+ page_title="MusoLyze",
20
+ page_icon="🤖",
21
+ layout="wide",
22
+ initial_sidebar_state="expanded",
23
+ )
24
+
25
+ # -----------------------------------------------------------------------
26
+ # Constants
27
+ AUTH_TOKEN = os.environ.get("AUTH_TOKEN")
28
+ ACCESS_JSON_PATH = "access.json"
29
+ TABLE_CONFIG_PATH = "table_config.json"
30
+ CSS_PATH = "style.css"
31
+
32
+ with open(CSS_PATH, "r") as f:
33
+ css_text = f.read()
34
+ st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)
35
+
36
+ # -----------------------------------------------------------------------
37
+ # Initialize Session State
38
+ if "authenticated" not in st.session_state:
39
+ st.session_state["authenticated"] = False
40
+ if "history" not in st.session_state:
41
+ st.session_state["history"] = []
42
+ if "dataframes" not in st.session_state:
43
+ st.session_state["dataframes"] = []
44
+ if "brand" not in st.session_state:
45
+ st.session_state["brand"] = None
46
+
47
+ # NEW: Track the previous selection of brand, tables, and uploaded file names.
48
+ if "previous_selection" not in st.session_state:
49
+ st.session_state["previous_selection"] = {
50
+ "brand": None,
51
+ "tables": [],
52
+ "uploaded_files": []
53
+ }
54
+
55
+ # -----------------------------------------------------------------------
56
+ # LOGIN PAGE
57
+ if not st.session_state["authenticated"]:
58
+ st.markdown('<div class="login-container">', unsafe_allow_html=True)
59
+ st.markdown("## MusoLyze Login")
60
+ st.write("Please enter your email and authentication token to proceed.")
61
+
62
+ email = st.text_input("Email", placeholder="john.doe@example.com")
63
+ token = st.text_input("Token", type="password", placeholder="Enter your token")
64
+
65
+ if st.button("Log In"):
66
+ if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
67
+ st.session_state["authenticated"] = True
68
+ st.success("Logged in successfully!")
69
+ st.stop() # Force the script to end; next run user is authenticated.
70
+ else:
71
+ st.error("Invalid email or token. Please try again.")
72
+
73
+ st.markdown('</div>', unsafe_allow_html=True)
74
+ st.stop() # Stop execution so the rest of the page is not shown.
75
+
76
+ # -----------------------------------------------------------------------
77
+ # Main App: Load Data, Show Chat
78
+ st.title("💬 MusoLyze")
79
+
80
+ # SmartQuery instance
81
+ sq = SmartQuery()
82
+
83
+ # Load config file for database tables
84
+ table_config = load_table_config(TABLE_CONFIG_PATH)
85
+
86
+ # Sidebar for file upload and table selection
87
+ st.sidebar.title("Data Selection")
88
+
89
+ # 1. File upload
90
+ uploaded_files = st.sidebar.file_uploader(
91
+ "Upload CSV or Excel files",
92
+ type=['csv', 'xlsx', 'xls'],
93
+ accept_multiple_files=True
94
+ )
95
+
96
+ # 2. Brand selection
97
+ brand = st.sidebar.selectbox("Choose your brand.", ["drumeo", "guitareo", "pianote", "singeo"])
98
+ st.session_state.brand = brand
99
+
100
+ # 3. Table selection
101
+ db_tables = st.sidebar.multiselect(
102
+ "Select tables from database",
103
+ options=list(table_config.keys()),
104
+ help="Select one or more tables to include in your data."
105
+ )
106
+
107
+ # Show table descriptions if user has selected any
108
+ display_table_descriptions(db_tables, table_config)
109
+
110
+ # 'Load Data' button
111
+ if st.sidebar.button("Load Data"):
112
+ # 1) Build the new selection object to compare with previous_selection.
113
+ new_selection = {
114
+ "brand": brand,
115
+ "tables": db_tables,
116
+ "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else []
117
+ }
118
+
119
+ # 2) Compare new selection with old selection; if changed, reset history.
120
+ if new_selection != st.session_state["previous_selection"]:
121
+ st.session_state["history"] = []
122
+
123
+ # 3) Proceed with loading data
124
+ dataframes = []
125
+
126
+ # Load from uploaded files
127
+ if uploaded_files:
128
+ dataframes.extend(load_uploaded_files(uploaded_files))
129
+
130
+ # Load dataframes from selected tables
131
+ if db_tables:
132
+ for table_name in db_tables:
133
+ table_info = table_config[table_name]
134
+ source = table_info["source"]
135
+ try:
136
+ if source == 'Snowflake':
137
+ session = sq.snowflake_connection()
138
+ df = sq.read_snowflake_table(session, table_name, st.session_state.brand)
139
+ elif source == 'MySQL':
140
+ engine = sq.mysql_connection()
141
+ df = sq.read_mysql_table(engine, table_name, st.session_state.brand)
142
+ dataframes.append(df)
143
+ except Exception as e:
144
+ st.error(f"Error loading table {table_name}: {e}")
145
+
146
+ st.session_state['dataframes'] = dataframes
147
+
148
+ # 4) Update previous_selection in session state
149
+ st.session_state["previous_selection"] = new_selection
150
+
151
+ st.success("Data loaded successfully!")
152
+
153
+ # --------------------------------------------------------------------------
154
+ # If no data is loaded, warn and stop
155
+ if not st.session_state['dataframes']:
156
+ st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
157
+ st.stop()
158
+
159
+ # **Always** display top 5 rows of each DataFrame if data is loaded
160
+ for idx, df in enumerate(st.session_state['dataframes']):
161
+ st.markdown(f"**Preview of loaded data:**")
162
+ st.dataframe(df.head(5))
163
+
164
+ # --- Chat Display Section ---
165
+ display_chat(st.session_state['history'])
166
+
167
+ # --- User Input Section ---
168
+ st.markdown("---")
169
+
170
+ with st.form(key="user_query_form"):
171
+ user_query = st.text_input(
172
+ "Ask a question about your data:",
173
+ placeholder="Type your question and press Enter..."
174
+ )
175
+ send_button = st.form_submit_button("Send")
176
+
177
+ if send_button and user_query.strip():
178
+ with st.spinner("Analyzing your data..."):
179
+ try:
180
+ response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])
181
+
182
+ if response['type'] == "dataframe":
183
+ df = response['value']
184
+ st.session_state['history'].append({
185
+ 'user': user_query,
186
+ 'type': 'dataframe',
187
+ 'bot': df # store the actual DataFrame
188
+ })
189
+ elif response['type'] == "plot":
190
+ plot_image = response['value']
191
+ st.session_state['history'].append({
192
+ 'user': user_query,
193
+ 'type': 'plot',
194
+ 'bot': plot_image
195
+ })
196
+ else: # string or any other text
197
+ text_response = response['value']
198
+ st.session_state['history'].append({
199
+ 'user': user_query,
200
+ 'type': 'string',
201
+ 'bot': text_response
202
+ })
203
+
204
+ # Rerun to refresh page and clear input
205
+ st.rerun()
206
+
207
+ except Exception as e:
208
+ st.error(f"Error: {e}")
209
+
210
+ elif send_button and not user_query.strip():
211
+ st.warning("Please enter a question before sending.")
style.css ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Base Theme */
2
+ body {
3
+ background-color: #000000;
4
+ color: #FFD700;
5
+ }
6
+ .stButton>button {
7
+ background-color: #FFD700;
8
+ color: #000000;
9
+ }
10
+ .stTextInput>div>div>input {
11
+ color: #FFD700;
12
+ border-color: #FFD700 !important;
13
+ }
14
+ .stSidebar {
15
+ background-color: #1E1E1E;
16
+ }
17
+
18
+ /* Center the login container */
19
+ .login-container {
20
+ max-width: 400px;
21
+ margin: 0 auto;
22
+ padding: 2rem;
23
+ background-color: #1E1E1E;
24
+ border-radius: 10px;
25
+ }
26
+ .login-container h2 {
27
+ text-align: center;
28
+ }
29
+
30
+ /* Chat-like bubbles */
31
+ .chat-bubble {
32
+ padding: 10px;
33
+ border-radius: 10px;
34
+ margin: 5px 0;
35
+ max-width: 80%;
36
+ word-wrap: break-word;
37
+ }
38
+ .user-bubble {
39
+ background-color: #1E1E1E;
40
+ border: 1px solid #FFD700;
41
+ align-self: flex-start;
42
+ }
43
+ .bot-bubble {
44
+ background-color: #FFD700;
45
+ color: #000;
46
+ align-self: flex-end;
47
+ }
table_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "interactions":
3
+ {
4
+ "description": "This table contains interaction history of the users with all the Musora content.",
5
+ "source": "Snowflake",
6
+ "cols": ["user_id", "content_id", "brand","TIMESTAMP", "EVENT_TEXT", "CONTENT_TYPE", "DIFFICULTY"],
7
+ "query": "select * from ONLINE_RECSYS.PREPROCESSED.RECSYS_INTEACTIONS where brand = '{brand}'"
8
+ },
9
+ "contents":
10
+ {
11
+ "description": "This table contains information about Musora contents.",
12
+ "source": "Snowflake",
13
+ "cols": ["content_id", "brand", "content_title", "content_type", "content_description", "artist", "difficulty", "STYLE", "TOPIC","published_at"],
14
+ "query": "select * from ONLINE_RECSYS.PREPROCESSED.CONTENTS where brand = '{brand}'"
15
+ },
16
+ "users":
17
+ {
18
+ "description": "This table contains information about Musora users.",
19
+ "source": "Snowflake",
20
+ "cols": ["USER_ID", "BRAND", "DIFFICULTY", "SELF_REPORT_DIFFICULTY", "USER_PROFILE", "PERMISSION","EXPIRATION_DATE"],
21
+ "query": "select * from ONLINE_RECSYS.PREPROCESSED.USERS where brand = '{brand}'"
22
+ }
23
+ }
utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ import json
4
+
5
+ def load_table_config(file_path: str) -> dict:
6
+ """Load the table configuration JSON."""
7
+ with open(file_path, 'r') as f:
8
+ return json.load(f)
9
+
10
+ def load_uploaded_files(uploaded_files):
11
+ """
12
+ Load dataframes from the uploaded files (CSV/Excel).
13
+ Returns a list of pandas DataFrames.
14
+ """
15
+ dataframes = []
16
+ for file in uploaded_files:
17
+ if file.name.endswith('.csv'):
18
+ df = pd.read_csv(file)
19
+ else:
20
+ df = pd.read_excel(file)
21
+ dataframes.append(df)
22
+ return dataframes
23
+
24
+ def display_table_descriptions(selected_tables, table_config):
25
+ """
26
+ Given a list of selected table names and the table config,
27
+ write out their descriptions in the sidebar.
28
+ """
29
+ if selected_tables:
30
+ st.sidebar.subheader("Table Descriptions")
31
+ for table_name in selected_tables:
32
+ description = table_config[table_name].get('description', "No description available.")
33
+ cols = table_config[table_name].get('cols', [])
34
+ st.sidebar.markdown(f"**{table_name}**: {description}")
35
+ st.sidebar.markdown(f"**Available columns**: {cols}")