"""Natural-language querying of pandas DataFrames through PandasAI.

The module wires an OpenAI-backed PandasAI agent to tables loaded from
Snowflake or MySQL. Credentials and API keys are read from Streamlit
secrets; SQL templates are read from ``table_config.json``.
"""
import json
import os
from urllib.parse import quote

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from pandasai import Agent
from pandasai import SmartDataframe, SmartDatalake
from pandasai.llm import OpenAI
from pandasai.responses.response_parser import ResponseParser
from pandasai.responses.streamlit_response import StreamlitResponse
from snowflake.snowpark import Session
from sqlalchemy import create_engine

load_dotenv()

# -----------------------------------------------------------------------
# Module-level setup: export the PandasAI key to the environment and build
# the shared OpenAI LLM client used by every query.
key = st.secrets["PANDASAI_API_KEY"]
os.environ['PANDASAI_API_KEY'] = key

openai_llm = OpenAI(
    api_token=st.secrets["OPENAI_API"]
)


# -----------------------------------------------------------------------
class SmartQuery:
    """Interact with dataframes using natural language.

    On construction the JSON file ``table_config.json`` (resolved relative to
    the current working directory) is loaded into ``self.config``. Each entry
    is looked up by table name and must provide a ``"query"`` string, which
    is formatted with a ``brand`` keyword (see ``_get_query``).
    """

    def __init__(self):
        # JSON is read as UTF-8 explicitly instead of the platform default.
        with open("table_config.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
        """Run a natural-language query on one or more pandas DataFrames.

        Parameters:
        - query (str): The user's query or instruction.
        - *dataframes (pd.DataFrame): One or more pandas DataFrames.
        - response_format: Accepted for backward compatibility; it is not
          used by this method.

        Returns:
        - The result of the query executed by PandasAI.

        Raises:
        - ValueError: If no dataframe is supplied.
        """
        dataframe_list = list(dataframes)
        if not dataframe_list:
            # Fail early instead of handing PandasAI an empty datalake.
            raise ValueError(
                "perform_query_on_dataframes requires at least one dataframe"
            )

        config = {
            "llm": openai_llm,
            "verbose": True,
            "security": "none",
            "response_parser": OutputParser,
        }

        if len(dataframe_list) == 1:
            return self.query_single_dataframe(query, dataframe_list[0], config)
        return self.query_multiple_dataframes(query, dataframe_list, config)

    def query_single_dataframe(self, query, dataframe, config):
        """Ask ``query`` of a single dataframe through a PandasAI ``Agent``."""
        agent = Agent(dataframe, config=config)
        return agent.chat(query)

    def query_multiple_dataframes(self, query, dataframe_list, config):
        """Ask ``query`` of several dataframes through a ``SmartDatalake``."""
        agent = SmartDatalake(dataframe_list, config=config)
        return agent.chat(query)

    # -----------------------------------------------------------------------
    def snowflake_connection(self):
        """Create a Snowpark session from the Snowflake Streamlit secrets.

        :return: the created ``Session``.
        :raises Exception: whatever session creation raises, after the error
            has been printed.
        """
        conn = {
            "user": st.secrets["snowflake_user"],
            "password": st.secrets["snowflake_password"],
            "account": st.secrets["snowflake_account"],
            "role": st.secrets["snowflake_role"],
            "database": st.secrets["snowflake_database"],
            "warehouse": st.secrets["snowflake_warehouse"],
            "schema": st.secrets["snowflake_schema"],
        }
        try:
            return Session.builder.configs(conn).create()
        except Exception as e:
            print(f"Error creating Snowflake session: {e}")
            raise

    # ------------------------------------------------------------------------
    def read_snowflake_table(self, session, table_name, brand):
        """Read a configured table from Snowflake into a pandas DataFrame.

        :param session: Snowpark session used to execute the SQL.
        :param table_name: key into ``self.config`` selecting the query template.
        :param brand: brand name substituted into the query (lower-cased).
        :return: the query result with all column names lower-cased.
        :raises Exception: whatever the query or conversion raises, after the
            error has been printed. The error is propagated (as the connection
            helpers do) rather than returning ``None`` to the caller.
        """
        query = self._get_query(table_name, brand)
        try:
            dataframe = session.sql(query).to_pandas()
            dataframe.columns = dataframe.columns.str.lower()
        except Exception as e:
            print(f"Error in reading table: {e}")
            raise
        print(f"reading {table_name} table successfully")
        return dataframe

    # ------------------------------------------------------------------------
    def _get_query(self, table_name: str, brand: str) -> str:
        """Build the SQL for ``table_name`` from its configured template.

        The template is ``self.config[table_name]["query"]``; it is formatted
        with ``brand`` lower-cased.

        :raises KeyError: if ``table_name`` (or its ``"query"`` entry) is not
            present in the configuration.
        """
        base_query = self.config[table_name]["query"]
        # NOTE(review): the brand is spliced into the SQL text with
        # str.format; make sure callers never pass untrusted input here.
        return base_query.format(brand=brand.lower())

    # ------------------------------------------------------------------------
    def mysql_connection(self):
        """Create a SQLAlchemy engine from the MySQL Streamlit secrets.

        The user name and password are percent-encoded before being placed in
        the connection URL so that characters such as ``@``, ``/`` or ``:``
        cannot corrupt it.

        :return: the SQLAlchemy engine (``mysql+pymysql`` dialect).
        :raises Exception: whatever engine creation raises, after the error
            has been printed.
        """
        user = quote(str(st.secrets["mysql_user"]), safe="")
        password = quote(str(st.secrets["mysql_password"]), safe="")
        host = st.secrets["mysql_source"]
        database = st.secrets["mysql_schema"]

        try:
            return create_engine(
                f"mysql+pymysql://{user}:{password}@{host}/{database}"
            )
        except Exception as e:
            print(f"Error creating MySQL engine: {e}")
            raise

    # ------------------------------------------------------------------------
    def read_mysql_table(self, engine, table_name, brand):
        """Read a configured table from MySQL into a pandas DataFrame.

        :param engine: SQLAlchemy engine used to open the connection.
        :param table_name: key into ``self.config`` selecting the query template.
        :param brand: brand name substituted into the query (lower-cased).
        :return: the query result with all column names lower-cased.
        """
        query = self._get_query(table_name, brand)
        with engine.connect() as conn:
            dataframe = pd.read_sql_query(query, conn)

        # Normalise column names to lowercase.
        dataframe.columns = dataframe.columns.str.lower()
        return dataframe


# ----------------------------------------------------------------------------
class OutputParser(ResponseParser):
    """Response parser that hands the PandasAI result back unchanged."""

    def __init__(self, context) -> None:
        super().__init__(context)

    def parse(self, result):
        """Return ``result`` as-is."""
        return result


# ----------------------------------------------------------------------------
if __name__ == "__main__":
    query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"
    query = "select the comments that was on 'pack-bundle-lesson' content_type and have more than 10 likes"
    query2 = "what is the number of likes, content_title and content_description for the content that received the most comments?  "

    dataframe_path = "data/recent_comment_test.csv"
    dataframe1 = pd.read_csv(dataframe_path)

    sq = SmartQuery()

    interactions_path = "DBT_ANALYTICS.CORE.FCT_CONTENT_INTERACTIONS"
    content_path = "DBT_ANALYTICS.CORE.DIM_CONTENT"

    session = sq.snowflake_connection()
    interactions_df = sq.read_snowflake_table(session, table_name="interactions", brand="drumeo")
    content_df = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")

    # single dataframe
    # result = sq.perform_query_on_dataframes(query, dataframe1, response_format="dataframe")

    # multiple dataframe
    result = sq.perform_query_on_dataframes(query_multi, interactions_df, content_df, response_format="dataframe")

    print(result)