# Musolyze / SmartQuery.py
# Danialebrat's picture
# updating secret managing
# b701a18
from pandasai.llm import OpenAI
from pandasai import Agent
from pandasai import SmartDataframe, SmartDatalake
from pandasai.responses.response_parser import ResponseParser
from pandasai.responses.streamlit_response import StreamlitResponse
from snowflake.snowpark import Session
import json
import pandas as pd
from sqlalchemy import create_engine
import os
from dotenv import load_dotenv
import streamlit as st
load_dotenv()
# -----------------------------------------------------------------------
# Copy the PandasAI key from Streamlit secrets into the process environment
# (presumably where PandasAI looks it up -- confirm against the library docs).
key = st.secrets["PANDASAI_API_KEY"]
os.environ["PANDASAI_API_KEY"] = key

# Module-wide LLM handle; SmartQuery passes it to PandasAI in the agent config.
openai_llm = OpenAI(api_token=st.secrets["OPENAI_API"])
# -----------------------------------------------------------------------
# -----------------------------------------------------------------------
class SmartQuery:
    """
    Class for interacting with dataframes using natural language (PandasAI),
    plus helpers that load those dataframes from Snowflake or MySQL.

    The SQL templates used by the loaders are read from ``table_config.json``
    (resolved relative to the current working directory) at construction time.
    """

    def __init__(self):
        # JSON is UTF-8 by specification; be explicit instead of relying on
        # the platform's default text encoding.
        with open("table_config.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
        """
        Performs a user-defined query on given pandas DataFrames using PandasAI.

        Parameters:
        - query (str): The user's query or instruction.
        - *dataframes (pd.DataFrame): One or more pandas DataFrames.
        - response_format: Accepted for backward compatibility; it is
          currently not used by this method.

        Returns:
        - The result of the query executed by PandasAI.

        Raises:
        - ValueError: If no DataFrame is supplied.
        """
        dataframe_list = list(dataframes)
        if not dataframe_list:
            # Fail early with a clear message rather than handing PandasAI an
            # empty list of dataframes.
            raise ValueError(
                "perform_query_on_dataframes requires at least one dataframe"
            )

        # Named agent_config to avoid confusion with self.config (the table
        # query templates).
        agent_config = {
            "llm": openai_llm,
            "verbose": True,
            "security": "none",
            "response_parser": OutputParser,
        }

        if len(dataframe_list) == 1:
            return self.query_single_dataframe(query, dataframe_list[0], agent_config)
        return self.query_multiple_dataframes(query, dataframe_list, agent_config)

    def query_single_dataframe(self, query, dataframe, config):
        """Run ``query`` against one dataframe through a PandasAI ``Agent``."""
        agent = Agent(dataframe, config=config)
        return agent.chat(query)

    def query_multiple_dataframes(self, query, dataframe_list, config):
        """Run ``query`` against several dataframes through a ``SmartDatalake``."""
        agent = SmartDatalake(dataframe_list, config=config)
        return agent.chat(query)

    # -----------------------------------------------------------------------
    def snowflake_connection(self):
        """
        Create a Snowpark session from the ``snowflake_*`` Streamlit secrets.

        :return: the created Snowpark ``Session``
        :raises Exception: whatever session creation raised, after printing it
        """
        conn = {
            "user": st.secrets["snowflake_user"],
            "password": st.secrets["snowflake_password"],
            "account": st.secrets["snowflake_account"],
            "role": st.secrets["snowflake_role"],
            "database": st.secrets["snowflake_database"],
            "warehouse": st.secrets["snowflake_warehouse"],
            "schema": st.secrets["snowflake_schema"],
        }
        try:
            return Session.builder.configs(conn).create()
        except Exception as e:
            print(f"Error creating Snowflake session: {e}")
            raise

    # ------------------------------------------------------------------------
    def read_snowflake_table(self, session, table_name, brand):
        """
        Read a configured table from Snowflake into a pandas DataFrame.

        :param session: Snowpark session used to run the query
        :param table_name: key into the table configuration (see ``_get_query``)
        :param brand: brand substituted into the query template
        :return: DataFrame with lower-cased column names, or ``None`` when
            running the query fails (the error is printed, not raised)
        :raises KeyError: if ``table_name`` has no configured query
        """
        query = self._get_query(table_name, brand)
        try:
            dataframe = session.sql(query).to_pandas()
            dataframe.columns = dataframe.columns.str.lower()
            print("reading content table successfully")
            return dataframe
        except Exception as e:
            # Best-effort read: report the failure and let the caller deal
            # with a missing dataframe.
            print(f"Error in reading table: {e}")
            return None

    # ------------------------------------------------------------------------
    def _get_query(self, table_name: str, brand: str) -> str:
        """
        Build the SQL text for ``table_name`` with ``brand`` filled in.

        :raises KeyError: if the configuration has no query for ``table_name``
        """
        # Retrieve the base query template for the given table name
        try:
            base_query = self.config[table_name]["query"]
        except KeyError as err:
            raise KeyError(
                f"No query template configured for table {table_name!r}"
            ) from err
        # Insert the (lower-cased) brand into the template.
        # NOTE(review): brand is substituted into the SQL text as-is; callers
        # should only pass trusted brand names.
        return base_query.format(brand=brand.lower())

    # ------------------------------------------------------------------------
    def mysql_connection(self):
        """
        Create a SQLAlchemy engine from the ``mysql_*`` Streamlit secrets.

        :return: SQLAlchemy engine using the ``mysql+pymysql`` driver
        :raises Exception: whatever engine creation raised, after printing it
        """
        # Setting up the MySQL connection parameters
        user = st.secrets["mysql_user"]
        password = st.secrets["mysql_password"]
        host = st.secrets["mysql_source"]
        database = st.secrets["mysql_schema"]
        try:
            # NOTE(review): credentials are placed in the URL verbatim; values
            # containing URL-reserved characters (e.g. '@' or '/') would need
            # escaping -- confirm how the secrets are stored.
            return create_engine(
                f"mysql+pymysql://{user}:{password}@{host}/{database}"
            )
        except Exception as e:
            print(f"Error creating MySQL engine: {e}")
            raise

    # ------------------------------------------------------------------------
    def read_mysql_table(self, engine, table_name, brand):
        """
        Read a configured table from MySQL into a pandas DataFrame.

        :param engine: SQLAlchemy engine to connect with
        :param table_name: key into the table configuration (see ``_get_query``)
        :param brand: brand substituted into the query template
        :return: DataFrame with lower-cased column names
        """
        query = self._get_query(table_name, brand)
        with engine.connect() as conn:
            dataframe = pd.read_sql_query(query, conn)
        # Convert all column names to lowercase if not
        dataframe.columns = dataframe.columns.str.lower()
        return dataframe
# ----------------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------------
class OutputParser(ResponseParser):
    """Response parser that hands results back to the caller unchanged."""

    def __init__(self, context) -> None:
        # No state of its own; setup is left entirely to the base parser.
        super().__init__(context)

    def parse(self, result):
        """Return ``result`` exactly as received, with no formatting step."""
        return result
# ----------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Manual smoke test against live Snowflake data (needs Streamlit secrets
    # and table_config.json to be available).

    # Sample prompts; query_multi spans both tables, while query and query2
    # are single-table prompts that can be passed with one dataframe instead.
    query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"
    query = "select the comments that was on 'pack-bundle-lesson' content_type and have more than 10 likes"
    query2 = "what is the number of likes, content_title and content_description for the content that received the most comments? "

    sq = SmartQuery()
    session = sq.snowflake_connection()
    try:
        interactions_df = sq.read_snowflake_table(session, table_name="interactions", brand="drumeo")
        content_df = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")

        # read_snowflake_table returns None when a read fails (after printing
        # the error), so stop here instead of passing None to PandasAI.
        if interactions_df is None or content_df is None:
            raise SystemExit("Could not load the Snowflake tables; see the error above.")

        # multiple dataframe
        result = sq.perform_query_on_dataframes(query_multi, interactions_df, content_df, response_format="dataframe")
        print(result)
    finally:
        # Always release the Snowflake session, even if a query fails.
        session.close()