Spaces:
Sleeping
Sleeping
Commit ·
fe8a467
1
Parent(s): 806f16d
adding codes and files
Browse files- .dockerignore +5 -0
- .gitignore +5 -0
- README.md +2 -14
- SmartQuery.py +190 -0
- SmartQuery_GC.py +190 -0
- access.json +3 -0
- app.py +210 -2
- auth.py +14 -0
- chat_ui.py +55 -0
- local_app.py +211 -0
- style.css +47 -0
- table_config.json +23 -0
- utils.py +35 -0
.dockerignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore the .streamlit directory and its contents
|
| 2 |
+
.streamlit/
|
| 3 |
+
|
| 4 |
+
# Ignore the .env file
|
| 5 |
+
.env
|
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore the .streamlit directory and its contents
|
| 2 |
+
.streamlit/
|
| 3 |
+
|
| 4 |
+
# Ignore the .env file
|
| 5 |
+
.env
|
README.md
CHANGED
|
@@ -1,14 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
emoji: 🚀
|
| 4 |
-
colorFrom: purple
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: streamlit
|
| 7 |
-
sdk_version: 1.43.2
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
short_description: Analyzing Musora databases using natural language
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
# SmartQuery
|
| 2 |
+
Ask questions from your data in natural language
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SmartQuery.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pandasai.llm import OpenAI
|
| 2 |
+
from pandasai import Agent
|
| 3 |
+
from pandasai import SmartDataframe, SmartDatalake
|
| 4 |
+
from pandasai.responses.response_parser import ResponseParser
|
| 5 |
+
from pandasai.responses.streamlit_response import StreamlitResponse
|
| 6 |
+
from snowflake.snowpark import Session
|
| 7 |
+
import json
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from sqlalchemy import create_engine
|
| 10 |
+
import os
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
import streamlit as st
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# -----------------------------------------------------------------------
|
| 16 |
+
key = st.secrets["pandasai"]["PANDASAI_API_KEY"]
|
| 17 |
+
os.environ['PANDASAI_API_KEY'] = key
|
| 18 |
+
openai_llm = OpenAI(
|
| 19 |
+
api_token=st.secrets["openai"]["OPENAI_API"]
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# -----------------------------------------------------------------------
|
| 23 |
+
# -----------------------------------------------------------------------
|
| 24 |
+
class SmartQuery:
    """Interact with pandas DataFrames using natural language via PandasAI.

    Also bundles helpers for loading source tables from Snowflake and MySQL.
    Per-table SQL templates are read from ``table_config.json``.
    """

    def __init__(self):
        # Maps table name -> {"query": <SQL template with a {brand} placeholder>, ...}
        with open("table_config.json", "r") as f:
            self.config = json.load(f)

    def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
        """Performs a user-defined query on given pandas DataFrames using PandasAI.

        Parameters:
        - query (str): The user's query or instruction.
        - *dataframes (pd.DataFrame): Any number of pandas DataFrames.
        - response_format: currently unused; kept for interface compatibility.

        Returns:
        - The result of the query executed by PandasAI.
        """
        dataframe_list = list(dataframes)

        # "security": "none" lets PandasAI execute generated code unrestricted;
        # OutputParser hands back the raw response payload untouched.
        config = {
            "llm": openai_llm,
            "verbose": True,
            "security": "none",
            "response_parser": OutputParser,
        }

        if len(dataframe_list) == 1:
            return self.query_single_dataframe(query, dataframe_list[0], config)
        return self.query_multiple_dataframes(query, dataframe_list, config)

    def query_single_dataframe(self, query, dataframe, config):
        """Chat with a single DataFrame through a PandasAI Agent."""
        agent = Agent(dataframe, config=config)
        return agent.chat(query)

    def query_multiple_dataframes(self, query, dataframe_list, config):
        """Chat with several DataFrames through a SmartDatalake."""
        agent = SmartDatalake(dataframe_list, config=config)
        return agent.chat(query)

    # -----------------------------------------------------------------------
    def snowflake_connection(self):
        """Create a Snowflake Snowpark session from environment variables.

        :return: an open ``snowflake.snowpark.Session``.
        :raises Exception: re-raised from the Snowpark builder on failure.
        """
        conn = {
            "user": os.environ.get("snowflake_user"),
            "password": os.environ.get("snowflake_password"),
            "account": os.environ.get("snowflake_account"),
            "role": os.environ.get("snowflake_role"),
            "database": os.environ.get("snowflake_database"),
            "warehouse": os.environ.get("snowflake_warehouse"),
            "schema": os.environ.get("snowflake_schema"),
        }
        try:
            return Session.builder.configs(conn).create()
        except Exception as e:
            print(f"Error creating Snowflake session: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_snowflake_table(self, session, table_name, brand):
        """Read a configured table from Snowflake into a pandas DataFrame.

        :param session: an open Snowpark session.
        :param table_name: key into table_config.json.
        :param brand: brand name substituted into the SQL template.
        :return: DataFrame with lower-cased column names.
        :raises Exception: re-raised if the query fails.
        """
        query = self._get_query(table_name, brand)
        try:
            dataframe = session.sql(query).to_pandas()
            dataframe.columns = dataframe.columns.str.lower()
            print(f"read table {table_name} successfully")
            return dataframe
        except Exception as e:
            # Previously the exception was swallowed and the method implicitly
            # returned None, which crashed callers later with a confusing
            # AttributeError; re-raise so the caller can handle it.
            print(f"Error in reading table: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def _get_query(self, table_name: str, brand: str) -> str:
        """Build the SQL for *table_name* with the (lower-cased) brand filled in."""
        # Retrieve the base query template for the given table name
        base_query = self.config[table_name]["query"]
        # Insert the brand condition into the query
        return base_query.format(brand=brand.lower())

    # ----------------------------------------------------------------------------------------------------
    def mysql_connection(self):
        """Create a SQLAlchemy engine for the configured MySQL database.

        :raises Exception: re-raised from ``create_engine`` on failure.
        """
        user = os.environ.get("mysql_user")
        password = os.environ.get("mysql_password")
        host = os.environ.get("mysql_source")
        database = os.environ.get("mysql_schema")
        try:
            return create_engine(f"mysql+pymysql://{user}:{password}@{host}/{database}")
        except Exception as e:
            print(f"Error creating MySQL engine: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_mysql_table(self, engine, table_name, brand):
        """Read a configured table from MySQL into a pandas DataFrame."""
        query = self._get_query(table_name, brand)

        with engine.connect() as conn:
            dataframe = pd.read_sql_query(query, conn)

        # Normalise all column names to lowercase
        dataframe.columns = dataframe.columns.str.lower()
        return dataframe
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ----------------------------------------------------------------------------------------------------
|
| 154 |
+
# ----------------------------------------------------------------------------------------------------
|
| 155 |
+
class OutputParser(ResponseParser):
    """Pass-through parser: hand PandasAI's raw response payload to the caller."""

    def __init__(self, context) -> None:
        super().__init__(context)

    def parse(self, result):
        # No transformation — callers inspect result['type'] / result['value'].
        return result
|
| 161 |
+
|
| 162 |
+
# ----------------------------------------------------------------------------------------------------
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
    # Ad-hoc smoke test: load two Snowflake tables and ask a cross-table
    # natural-language question.  Requires Snowflake credentials in the env.
    # (The previous version also read an unrelated local CSV and defined
    # several unused queries, which could crash the smoke test before the
    # actual code under test ran.)
    query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"

    sq = SmartQuery()
    session = sq.snowflake_connection()

    interactions_df = sq.read_snowflake_table(session, table_name="interactions", brand="drumeo")
    content_df = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")

    # Multiple-dataframe query; a single-dataframe query works the same way
    # with just one DataFrame argument.
    result = sq.perform_query_on_dataframes(query_multi, interactions_df, content_df, response_format="dataframe")
    print(result)
|
SmartQuery_GC.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# this class will use env variables to read secrets from google cloud
|
| 2 |
+
from pandasai import Agent
|
| 3 |
+
from pandasai import SmartDataframe, SmartDatalake
|
| 4 |
+
from pandasai.responses.response_parser import ResponseParser
|
| 5 |
+
from pandasai.llm.openai import OpenAI
|
| 6 |
+
from pandasai.responses.streamlit_response import StreamlitResponse
|
| 7 |
+
import pymysql
|
| 8 |
+
from pandasai.connectors import PandasConnector
|
| 9 |
+
from snowflake.snowpark import Session
|
| 10 |
+
import json
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from sqlalchemy import create_engine
|
| 13 |
+
import os
|
| 14 |
+
import streamlit as st
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
load_dotenv()
import datetime  # NOTE(review): unused in this module — candidate for removal

# -----------------------------------------------------------------------
# Propagate the PandasAI key into the environment only when it is actually
# set: assigning None to os.environ raises TypeError, which previously
# crashed the whole module when PANDASAI_API_KEY was missing.
key = os.environ.get("PANDASAI_API_KEY")
if key:
    os.environ['PANDASAI_API_KEY'] = key

openai_llm = OpenAI(
    api_token=os.environ.get("OPENAI_API")
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# -----------------------------------------------------------------------
|
| 32 |
+
class SmartQuery:
    """Interact with pandas DataFrames using natural language via PandasAI.

    Google-Cloud variant (secrets come from environment variables).  Also
    bundles helpers for loading source tables from Snowflake and MySQL; the
    per-table SQL templates are read from ``table_config.json``.
    """

    def __init__(self):
        # Maps table name -> {"query": <SQL template with a {brand} placeholder>, ...}
        with open("table_config.json", "r") as f:
            self.config = json.load(f)

    def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
        """Performs a user-defined query on given pandas DataFrames using PandasAI.

        Parameters:
        - query (str): The user's query or instruction.
        - *dataframes (pd.DataFrame): Any number of pandas DataFrames.
        - response_format: currently unused; kept for interface compatibility.

        Returns:
        - The result of the query executed by PandasAI.
        """
        dataframe_list = list(dataframes)

        # "security": "none" lets PandasAI execute generated code unrestricted;
        # OutputParser hands back the raw response payload untouched.
        config = {
            "llm": openai_llm,
            "verbose": True,
            "security": "none",
            "response_parser": OutputParser,
        }

        if len(dataframe_list) == 1:
            return self.query_single_dataframe(query, dataframe_list[0], config)
        return self.query_multiple_dataframes(query, dataframe_list, config)

    def query_single_dataframe(self, query, dataframe, config):
        """Chat with a single DataFrame through a PandasAI Agent."""
        agent = Agent(dataframe, config=config)
        return agent.chat(query)

    def query_multiple_dataframes(self, query, dataframe_list, config):
        """Chat with several DataFrames through a SmartDatalake."""
        agent = SmartDatalake(dataframe_list, config=config)
        return agent.chat(query)

    # -----------------------------------------------------------------------
    def snowflake_connection(self):
        """Create a Snowflake Snowpark session from environment variables.

        :return: an open ``snowflake.snowpark.Session``.
        :raises Exception: re-raised from the Snowpark builder on failure.
        """
        conn = {
            "user": os.environ.get("snowflake_user"),
            "password": os.environ.get("snowflake_password"),
            "account": os.environ.get("snowflake_account"),
            "role": os.environ.get("snowflake_role"),
            "database": os.environ.get("snowflake_database"),
            "warehouse": os.environ.get("snowflake_warehouse"),
            "schema": os.environ.get("snowflake_schema"),
        }
        try:
            return Session.builder.configs(conn).create()
        except Exception as e:
            print(f"Error creating Snowflake session: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_snowflake_table(self, session, table_name, brand):
        """Read a configured table from Snowflake into a pandas DataFrame.

        :param session: an open Snowpark session.
        :param table_name: key into table_config.json.
        :param brand: brand name substituted into the SQL template.
        :return: DataFrame with lower-cased column names.
        :raises Exception: re-raised if the query fails.
        """
        query = self._get_query(table_name, brand)
        try:
            dataframe = session.sql(query).to_pandas()
            dataframe.columns = dataframe.columns.str.lower()
            print(f"read table {table_name} successfully")
            return dataframe
        except Exception as e:
            # Previously the exception was swallowed and the method implicitly
            # returned None, which crashed callers later; re-raise instead.
            print(f"Error in reading table: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def _get_query(self, table_name: str, brand: str) -> str:
        """Build the SQL for *table_name* with the (lower-cased) brand filled in."""
        # Retrieve the base query template for the given table name
        base_query = self.config[table_name]["query"]
        # Insert the brand condition into the query
        return base_query.format(brand=brand.lower())

    # ----------------------------------------------------------------------------------------------------
    def mysql_connection(self):
        """Create a SQLAlchemy engine for the configured MySQL database.

        :raises Exception: re-raised from ``create_engine`` on failure.
        """
        user = os.environ.get("mysql_user")
        password = os.environ.get("mysql_password")
        host = os.environ.get("mysql_source")
        database = os.environ.get("mysql_schema")
        try:
            return create_engine(f"mysql+pymysql://{user}:{password}@{host}/{database}")
        except Exception as e:
            print(f"Error creating MySQL engine: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_mysql_table(self, engine, table_name, brand):
        """Read a configured table from MySQL into a pandas DataFrame."""
        query = self._get_query(table_name, brand)

        with engine.connect() as conn:
            dataframe = pd.read_sql_query(query, conn)

        # Normalise all column names to lowercase
        dataframe.columns = dataframe.columns.str.lower()
        return dataframe
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ----------------------------------------------------------------------------------------------------
|
| 162 |
+
# ----------------------------------------------------------------------------------------------------
|
| 163 |
+
class OutputParser(ResponseParser):
    """Pass-through parser: hand PandasAI's raw response payload to the caller."""

    def __init__(self, context) -> None:
        super().__init__(context)

    def parse(self, result):
        # No transformation — callers inspect result['type'] / result['value'].
        return result
|
| 169 |
+
|
| 170 |
+
# ----------------------------------------------------------------------------------------------------
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if __name__ == "__main__":
    # Ad-hoc smoke test against the Snowflake "contents" table.
    # (Earlier experimental queries and CSV loading were dead, commented-out
    # code and have been removed.)
    query = "select song content_type that have difficulty range of 0-3?"

    sq = SmartQuery()
    session = sq.snowflake_connection()
    dataframe = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")

    result = sq.perform_query_on_dataframes(query, dataframe)
    print(result)
|
access.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"email": ["danial@musora.com", "danial.ebrat@gmail.com"]
|
| 3 |
+
}
|
app.py
CHANGED
|
@@ -1,4 +1,212 @@
|
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import pandas as pd
import streamlit as st
from dotenv import load_dotenv

# Local imports
from auth import authenticator
from utils import load_table_config, load_uploaded_files, display_table_descriptions
from SmartQuery import SmartQuery
from chat_ui import display_chat

load_dotenv()

# -----------------------------------------------------------------------
# Page setup
st.set_page_config(
    page_title="MusoLyze",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# -----------------------------------------------------------------------
# Constants
AUTH_TOKEN = st.secrets["token"]["AUTH_TOKEN"]
ACCESS_JSON_PATH = "access.json"
TABLE_CONFIG_PATH = "table_config.json"
CSS_PATH = "style.css"

# Inject the custom stylesheet once per run.
with open(CSS_PATH, "r") as f:
    css_text = f.read()
st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)

# -----------------------------------------------------------------------
# Session-state initialisation
if "authenticated" not in st.session_state:
    st.session_state["authenticated"] = False
if "history" not in st.session_state:
    st.session_state["history"] = []
if "dataframes" not in st.session_state:
    st.session_state["dataframes"] = []
if "brand" not in st.session_state:
    st.session_state["brand"] = None

# Track the previous selection of brand, tables and uploaded file names so
# the chat history can be reset whenever the data selection changes.
if "previous_selection" not in st.session_state:
    st.session_state["previous_selection"] = {
        "brand": None,
        "tables": [],
        "uploaded_files": []
    }

# -----------------------------------------------------------------------
# LOGIN PAGE
if not st.session_state["authenticated"]:
    st.markdown('<div class="login-container">', unsafe_allow_html=True)
    st.markdown("## MusoLyze Login")
    st.write("Please enter your email and authentication token to proceed.")

    email = st.text_input("Email", placeholder="john.doe@example.com")
    token = st.text_input("Token", type="password", placeholder="Enter your token")

    if st.button("Log In"):
        if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
            st.session_state["authenticated"] = True
            st.success("Logged in successfully!")
            st.stop()  # Force the script to end; next run user is authenticated.
        else:
            st.error("Invalid email or token. Please try again.")

    st.markdown('</div>', unsafe_allow_html=True)
    st.stop()  # Stop execution so the rest of the page is not shown.

# -----------------------------------------------------------------------
# Main App: Load Data, Show Chat
st.title("💬 MusoLyze")

# SmartQuery instance and per-table configuration
sq = SmartQuery()
table_config = load_table_config(TABLE_CONFIG_PATH)

# Sidebar for file upload and table selection
st.sidebar.title("Data Selection")

# 1. File upload
uploaded_files = st.sidebar.file_uploader(
    "Upload CSV or Excel files",
    type=['csv', 'xlsx', 'xls'],
    accept_multiple_files=True
)

# 2. Brand selection
brand = st.sidebar.selectbox("Choose your brand.", ["drumeo", "guitareo", "pianote", "singeo"])
st.session_state.brand = brand

# 3. Table selection
db_tables = st.sidebar.multiselect(
    "Select tables from database",
    options=list(table_config.keys()),
    help="Select one or more tables to include in your data."
)

# Show table descriptions if user has selected any
display_table_descriptions(db_tables, table_config)

# 'Load Data' button
if st.sidebar.button("Load Data"):
    # Compare the new selection with the previous one; a change invalidates
    # the existing chat history.
    new_selection = {
        "brand": brand,
        "tables": db_tables,
        "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else []
    }
    if new_selection != st.session_state["previous_selection"]:
        st.session_state["history"] = []

    dataframes = []

    # Load from uploaded files
    if uploaded_files:
        dataframes.extend(load_uploaded_files(uploaded_files))

    # Load dataframes from the selected database tables
    if db_tables:
        for table_name in db_tables:
            table_info = table_config[table_name]
            source = table_info["source"]
            try:
                if source == 'Snowflake':
                    session = sq.snowflake_connection()
                    df = sq.read_snowflake_table(session, table_name, st.session_state.brand)
                elif source == 'MySQL':
                    engine = sq.mysql_connection()
                    df = sq.read_mysql_table(engine, table_name, st.session_state.brand)
                else:
                    # Previously an unknown source fell through with 'df'
                    # undefined, surfacing as a confusing NameError.
                    raise ValueError(f"Unknown source '{source}' for table {table_name}")
                dataframes.append(df)
            except Exception as e:
                st.error(f"Error loading table {table_name}: {e}")

    st.session_state['dataframes'] = dataframes
    st.session_state["previous_selection"] = new_selection
    st.success("Data loaded successfully!")

# --------------------------------------------------------------------------
# If no data is loaded, warn and stop
if not st.session_state['dataframes']:
    st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
    st.stop()

# Always display the top 5 rows of each loaded DataFrame
for df in st.session_state['dataframes']:
    st.markdown("**Preview of loaded data:**")
    st.dataframe(df.head(5))

# --- Chat Display Section ---
display_chat(st.session_state['history'])

# --- User Input Section ---
st.markdown("---")

with st.form(key="user_query_form"):
    user_query = st.text_input(
        "Ask a question about your data:",
        placeholder="Type your question and press Enter..."
    )
    send_button = st.form_submit_button("Send")

if send_button and user_query.strip():
    with st.spinner("Analyzing your data..."):
        try:
            response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])

            # OutputParser returns PandasAI's raw payload:
            # {'type': 'dataframe' | 'plot' | <other>, 'value': ...}.
            # Anything that is not a dataframe or plot is stored as a string.
            response_type = response['type']
            if response_type not in ("dataframe", "plot"):
                response_type = 'string'
            st.session_state['history'].append({
                'user': user_query,
                'type': response_type,
                'bot': response['value']
            })

            # Rerun to refresh the page and clear the input
            st.rerun()

        except Exception as e:
            st.error(f"Error: {e}")

elif send_button and not user_query.strip():
    st.warning("Please enter a question before sending.")
|
auth.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hmac
import json
import os


def load_access_json(file_path: str) -> dict:
    """Load and return the JSON allow-list ({"email": [...]}) at *file_path*."""
    with open(file_path, 'r') as f:
        return json.load(f)


def authenticator(email: str, token: str, auth_token: str, access_json_path: str) -> bool:
    """Return True when *email* is allow-listed and *token* matches *auth_token*.

    Emails are compared case-insensitively on both sides, so entries in
    access.json need not be pre-normalised (previously an entry containing any
    upper-case letter could never match, since only the input was lowered).
    The token comparison uses hmac.compare_digest to avoid a timing
    side-channel on the shared secret.
    """
    emails_data = load_access_json(access_json_path)
    allowed = {e.lower() for e in emails_data["email"]}
    return email.lower() in allowed and hmac.compare_digest(token, auth_token)
|
chat_ui.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
def display_chat(history):
    """Render the chat history as alternating user/bot chat bubbles.

    Args:
        history: List of dicts, one per exchange, with keys:
            'user' - the question text,
            'type' - 'dataframe' | 'plot' | 'string' (optional; any
                missing or unknown value falls through to text rendering),
            'bot'  - the response payload (DataFrame, image, or text).
    """
    chat_container = st.container()
    with chat_container:
        for idx, chat in enumerate(history):
            # --- User message ---
            # Raw HTML so the .chat-bubble/.user-bubble CSS classes apply.
            # NOTE(review): chat['user'] is interpolated into HTML without
            # escaping — confirm the input is trusted or escape it.
            st.markdown(
                f"""
                <div class="chat-bubble user-bubble">
                    <strong>You:</strong> {chat['user']}
                </div>
                """,
                unsafe_allow_html=True
            )

            # --- Bot bubble: use the 'type' key to decide how to render ---
            # NOTE(review): this opens a <div> in one st.markdown call and
            # closes it in a separate call at the bottom of the loop.
            # Streamlit renders each call as its own element, so the widgets
            # emitted in between are likely NOT actually wrapped — verify the
            # intended styling survives.
            st.markdown(
                """
                <div class="chat-bubble bot-bubble">
                    <strong>Bot:</strong>
                """,
                unsafe_allow_html=True,
            )

            response_type = chat.get('type', 'string')  # default to 'string'
            bot_response = chat['bot']

            if response_type == 'dataframe' and isinstance(bot_response, pd.DataFrame):
                # Show top 5 rows only; the full frame is available via download.
                df_to_display = bot_response
                if len(df_to_display) > 5:
                    st.info("Showing the first 5 rows of the DataFrame.")
                st.dataframe(df_to_display.head(5))

                # Provide a CSV download of the FULL DataFrame.
                csv_data = df_to_display.to_csv(index=False).encode('utf-8')
                st.download_button(
                    label="Download data as CSV",
                    data=csv_data,
                    file_name=f'result_{idx+1}.csv',
                    mime='text/csv',
                    # Unique key per history entry so Streamlit does not
                    # collide widget ids across loop iterations.
                    key=f'download_{idx}'
                )

            elif response_type == 'plot':
                # If it's an image object (e.g., PIL Image), show it
                st.image(bot_response, use_container_width=True)

            else:  # "string" or any other text
                st.markdown(f"{bot_response}", unsafe_allow_html=True)

            # Close the bot bubble opened above (see NOTE(review) there).
            st.markdown("</div>", unsafe_allow_html=True)
|
local_app.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MusoLyze — local Streamlit entry point.

Flow: login gate -> sidebar data selection (uploads + warehouse tables)
-> data preview -> chat loop driven by SmartQuery. Statement order matters
throughout: st.set_page_config must run before any other Streamlit call,
and st.stop() is used to short-circuit the script on each rerun.
"""
import os
import streamlit as st
from dotenv import load_dotenv
import pandas as pd

# Local imports
from auth import authenticator
from utils import load_table_config, load_uploaded_files, display_table_descriptions
# from SmartQuery_GC import SmartQuery
from SmartQuery import SmartQuery
# If you use chat_ui.py:
from chat_ui import display_chat

# Pull AUTH_TOKEN (and any DB credentials) from a local .env file.
load_dotenv()

# -----------------------------------------------------------------------
# Set page config — must be the first Streamlit command executed.
st.set_page_config(
    page_title="MusoLyze",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# -----------------------------------------------------------------------
# Constants
AUTH_TOKEN = os.environ.get("AUTH_TOKEN")  # None if the env var is unset
ACCESS_JSON_PATH = "access.json"           # allow-listed emails
TABLE_CONFIG_PATH = "table_config.json"    # warehouse table metadata
CSS_PATH = "style.css"                     # custom theme

# Inject the custom CSS theme into the page.
with open(CSS_PATH, "r") as f:
    css_text = f.read()
st.markdown(f"<style>{css_text}</style>", unsafe_allow_html=True)

# -----------------------------------------------------------------------
# Initialize Session State (idempotent: guarded so reruns keep state)
if "authenticated" not in st.session_state:
    st.session_state["authenticated"] = False
if "history" not in st.session_state:
    st.session_state["history"] = []       # chat exchanges for display_chat
if "dataframes" not in st.session_state:
    st.session_state["dataframes"] = []    # currently loaded DataFrames
if "brand" not in st.session_state:
    st.session_state["brand"] = None

# NEW: Track the previous selection of brand, tables, and uploaded file names,
# so the chat history can be reset when the underlying data changes.
if "previous_selection" not in st.session_state:
    st.session_state["previous_selection"] = {
        "brand": None,
        "tables": [],
        "uploaded_files": []
    }

# -----------------------------------------------------------------------
# LOGIN PAGE — everything below this block is unreachable until authenticated.
if not st.session_state["authenticated"]:
    st.markdown('<div class="login-container">', unsafe_allow_html=True)
    st.markdown("## MusoLyze Login")
    st.write("Please enter your email and authentication token to proceed.")

    email = st.text_input("Email", placeholder="john.doe@example.com")
    token = st.text_input("Token", type="password", placeholder="Enter your token")

    if st.button("Log In"):
        if authenticator(email, token, AUTH_TOKEN, ACCESS_JSON_PATH):
            st.session_state["authenticated"] = True
            st.success("Logged in successfully!")
            st.stop()  # Force the script to end; next run user is authenticated.
        else:
            st.error("Invalid email or token. Please try again.")

    st.markdown('</div>', unsafe_allow_html=True)
    st.stop()  # Stop execution so the rest of the page is not shown.

# -----------------------------------------------------------------------
# Main App: Load Data, Show Chat
st.title("💬 MusoLyze")

# SmartQuery instance (wraps the LLM + DB connections)
sq = SmartQuery()

# Load config file for database tables
table_config = load_table_config(TABLE_CONFIG_PATH)

# Sidebar for file upload and table selection
st.sidebar.title("Data Selection")

# 1. File upload
uploaded_files = st.sidebar.file_uploader(
    "Upload CSV or Excel files",
    type=['csv', 'xlsx', 'xls'],
    accept_multiple_files=True
)

# 2. Brand selection
brand = st.sidebar.selectbox("Choose your brand.", ["drumeo", "guitareo", "pianote", "singeo"])
st.session_state.brand = brand

# 3. Table selection
db_tables = st.sidebar.multiselect(
    "Select tables from database",
    options=list(table_config.keys()),
    help="Select one or more tables to include in your data."
)

# Show table descriptions if user has selected any
display_table_descriptions(db_tables, table_config)

# 'Load Data' button
if st.sidebar.button("Load Data"):
    # 1) Build the new selection object to compare with previous_selection.
    new_selection = {
        "brand": brand,
        "tables": db_tables,
        "uploaded_files": [f.name for f in uploaded_files] if uploaded_files else []
    }

    # 2) Compare new selection with old selection; if changed, reset history.
    if new_selection != st.session_state["previous_selection"]:
        st.session_state["history"] = []

    # 3) Proceed with loading data
    dataframes = []

    # Load from uploaded files
    if uploaded_files:
        dataframes.extend(load_uploaded_files(uploaded_files))

    # Load dataframes from selected tables
    if db_tables:
        for table_name in db_tables:
            table_info = table_config[table_name]
            source = table_info["source"]
            try:
                if source == 'Snowflake':
                    session = sq.snowflake_connection()
                    df = sq.read_snowflake_table(session, table_name, st.session_state.brand)
                elif source == 'MySQL':
                    engine = sq.mysql_connection()
                    df = sq.read_mysql_table(engine, table_name, st.session_state.brand)
                # NOTE(review): if 'source' is neither Snowflake nor MySQL,
                # 'df' is unbound (or stale from a previous iteration) here —
                # the resulting error is swallowed by the except below.
                # Confirm table_config can only contain these two sources.
                dataframes.append(df)
            except Exception as e:
                st.error(f"Error loading table {table_name}: {e}")

    st.session_state['dataframes'] = dataframes

    # 4) Update previous_selection in session state
    st.session_state["previous_selection"] = new_selection

    st.success("Data loaded successfully!")

# --------------------------------------------------------------------------
# If no data is loaded, warn and stop
if not st.session_state['dataframes']:
    st.warning("Please upload at least one file or select a table from the database, then click 'Load Data'.")
    st.stop()

# **Always** display top 5 rows of each DataFrame if data is loaded
for idx, df in enumerate(st.session_state['dataframes']):
    st.markdown(f"**Preview of loaded data:**")
    st.dataframe(df.head(5))

# --- Chat Display Section ---
display_chat(st.session_state['history'])

# --- User Input Section ---
st.markdown("---")

# A form so pressing Enter in the text box submits exactly once per rerun.
with st.form(key="user_query_form"):
    user_query = st.text_input(
        "Ask a question about your data:",
        placeholder="Type your question and press Enter..."
    )
    send_button = st.form_submit_button("Send")

if send_button and user_query.strip():
    with st.spinner("Analyzing your data..."):
        try:
            response = sq.perform_query_on_dataframes(user_query, *st.session_state['dataframes'])

            # response is a dict with 'type' and 'value'; store the exchange
            # in session history so display_chat can render it on rerun.
            if response['type'] == "dataframe":
                df = response['value']
                st.session_state['history'].append({
                    'user': user_query,
                    'type': 'dataframe',
                    'bot': df  # store the actual DataFrame
                })
            elif response['type'] == "plot":
                plot_image = response['value']
                st.session_state['history'].append({
                    'user': user_query,
                    'type': 'plot',
                    'bot': plot_image
                })
            else:  # string or any other text
                text_response = response['value']
                st.session_state['history'].append({
                    'user': user_query,
                    'type': 'string',
                    'bot': text_response
                })

            # Rerun to refresh page and clear input
            st.rerun()

        except Exception as e:
            st.error(f"Error: {e}")

elif send_button and not user_query.strip():
    st.warning("Please enter a question before sending.")
|
style.css
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Base Theme */
/* Black background with gold (#FFD700) accents across the whole app. */
body {
    background-color: #000000;
    color: #FFD700;
}
/* Primary buttons: inverted palette (gold background, black text). */
.stButton>button {
    background-color: #FFD700;
    color: #000000;
}
/* Text inputs: gold text and border to match the theme.
   !important overrides Streamlit's default border color. */
.stTextInput>div>div>input {
    color: #FFD700;
    border-color: #FFD700 !important;
}
/* Sidebar: dark grey panel to separate it from the black main area. */
.stSidebar {
    background-color: #1E1E1E;
}

/* Center the login container */
.login-container {
    max-width: 400px;
    margin: 0 auto;
    padding: 2rem;
    background-color: #1E1E1E;
    border-radius: 10px;
}
.login-container h2 {
    text-align: center;
}

/* Chat-like bubbles (used by chat_ui.display_chat) */
.chat-bubble {
    padding: 10px;
    border-radius: 10px;
    margin: 5px 0;
    max-width: 80%;
    word-wrap: break-word;
}
/* User messages: dark bubble with a gold outline, left-aligned. */
.user-bubble {
    background-color: #1E1E1E;
    border: 1px solid #FFD700;
    align-self: flex-start;
}
/* Bot messages: solid gold bubble with black text, right-aligned. */
.bot-bubble {
    background-color: #FFD700;
    color: #000;
    align-self: flex-end;
}
|
table_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"interactions":
|
| 3 |
+
{
|
| 4 |
+
"description": "This table contains interaction history of the users with all the Musora content.",
|
| 5 |
+
"source": "Snowflake",
|
| 6 |
+
"cols": ["user_id", "content_id", "brand","TIMESTAMP", "EVENT_TEXT", "CONTENT_TYPE", "DIFFICULTY"],
|
| 7 |
+
"query": "select * from ONLINE_RECSYS.PREPROCESSED.RECSYS_INTEACTIONS where brand = '{brand}'"
|
| 8 |
+
},
|
| 9 |
+
"contents":
|
| 10 |
+
{
|
| 11 |
+
"description": "This table contains information about Musora contents.",
|
| 12 |
+
"source": "Snowflake",
|
| 13 |
+
"cols": ["content_id", "brand", "content_title", "content_type", "content_description", "artist", "difficulty", "STYLE", "TOPIC","published_at"],
|
| 14 |
+
"query": "select * from ONLINE_RECSYS.PREPROCESSED.CONTENTS where brand = '{brand}'"
|
| 15 |
+
},
|
| 16 |
+
"users":
|
| 17 |
+
{
|
| 18 |
+
"description": "This table contains information about Musora users.",
|
| 19 |
+
"source": "Snowflake",
|
| 20 |
+
"cols": ["USER_ID", "BRAND", "DIFFICULTY", "SELF_REPORT_DIFFICULTY", "USER_PROFILE", "PERMISSION","EXPIRATION_DATE"],
|
| 21 |
+
"query": "select * from ONLINE_RECSYS.PREPROCESSED.USERS where brand = '{brand}'"
|
| 22 |
+
}
|
| 23 |
+
}
|
utils.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
def load_table_config(file_path: str) -> dict:
    """Parse the table-configuration JSON file.

    Args:
        file_path: Path to the table configuration JSON.

    Returns:
        The configuration as a dict keyed by table name.
    """
    with open(file_path, 'r') as config_file:
        config = json.load(config_file)
    return config
|
| 9 |
+
|
| 10 |
+
def load_uploaded_files(uploaded_files):
    """
    Load dataframes from the uploaded files (CSV/Excel).

    Args:
        uploaded_files: Iterable of file-like objects that expose a
            ``name`` attribute (e.g. Streamlit ``UploadedFile``).

    Returns:
        A list of pandas DataFrames, one per uploaded file, in order.
    """
    dataframes = []
    for file in uploaded_files:
        # Compare the extension case-insensitively: the previous check
        # used a raw endswith('.csv'), so "DATA.CSV" was silently handed
        # to the Excel reader and failed to parse.
        if file.name.lower().endswith('.csv'):
            df = pd.read_csv(file)
        else:
            df = pd.read_excel(file)
        dataframes.append(df)
    return dataframes
|
| 23 |
+
|
| 24 |
+
def display_table_descriptions(selected_tables, table_config):
    """
    Show the description and column list of each selected table in the
    Streamlit sidebar. Does nothing when no tables are selected.

    Args:
        selected_tables: Table names chosen by the user.
        table_config: Mapping of table name -> metadata dict with
            optional 'description' and 'cols' entries.
    """
    if not selected_tables:
        return

    st.sidebar.subheader("Table Descriptions")
    for table_name in selected_tables:
        entry = table_config[table_name]
        description = entry.get('description', "No description available.")
        cols = entry.get('cols', [])
        st.sidebar.markdown(f"**{table_name}**: {description}")
        st.sidebar.markdown(f"**Available columns**: {cols}")
|