# This module reads its secrets (PandasAI / OpenAI keys, Snowflake and MySQL
# credentials) from environment variables. A local .env file is loaded via
# python-dotenv; per the original note, in deployment these are presumably
# populated from Google Cloud secrets -- TODO confirm.
import datetime
import json
import os
from urllib.parse import quote

import pandas as pd
import pymysql
import streamlit as st
from dotenv import load_dotenv
from pandasai import Agent
from pandasai import SmartDataframe, SmartDatalake
from pandasai.connectors import PandasConnector
from pandasai.llm.openai import OpenAI
from pandasai.responses.response_parser import ResponseParser
from pandasai.responses.streamlit_response import StreamlitResponse
from snowflake.snowpark import Session
from sqlalchemy import create_engine

load_dotenv()

# -----------------------------------------------------------------------
# load_dotenv() has already populated os.environ, so this only re-asserts the
# PandasAI key. Guard against a missing key: assigning None into os.environ
# raises TypeError and would crash the module at import time.
key = os.environ.get("PANDASAI_API_KEY")
if key is not None:
    os.environ["PANDASAI_API_KEY"] = key

openai_llm = OpenAI(
    api_token=os.environ.get("OPENAI_API")
)


# -----------------------------------------------------------------------
class SmartQuery:
    """
    class for interacting with dataframes using Natural Language
    """

    def __init__(self, config_path="table_config.json"):
        """Load the per-table query configuration.

        :param config_path: path to a JSON file mapping table names to objects
            that hold a ``"query"`` template (see ``_get_query``). Defaults to
            ``table_config.json`` relative to the current working directory.
        """
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
        """
        Performs a user-defined query on given pandas DataFrames using PandasAI.

        Parameters:
        - query (str): The user's query or instruction.
        - *dataframes (pd.DataFrame): One or more pandas DataFrames.
        - response_format: currently unused; kept for backward compatibility.

        Returns:
        - The result of the query executed by PandasAI. ``OutputParser`` is
          configured as the response parser, so the raw result is returned
          unchanged.

        Raises:
        - ValueError: if no DataFrame is supplied.
        """
        if not dataframes:
            raise ValueError("perform_query_on_dataframes requires at least one DataFrame")

        # NOTE(review): "security": "none" presumably disables PandasAI's checks on
        # the LLM-generated code it executes -- only acceptable for trusted users.
        config = {
            "llm": openai_llm,
            "verbose": True,
            "security": "none",
            "response_parser": OutputParser,
        }

        if len(dataframes) == 1:
            return self.query_single_dataframe(query, dataframes[0], config)
        return self.query_multiple_dataframes(query, list(dataframes), config)

    def query_single_dataframe(self, query, dataframe, config):
        """Ask ``query`` against a single DataFrame via a PandasAI ``Agent``."""
        agent = Agent(dataframe, config=config)
        response = agent.chat(query)
        return response

    def query_multiple_dataframes(self, query, dataframe_list, config):
        """Ask ``query`` across several DataFrames via a PandasAI ``SmartDatalake``."""
        agent = SmartDatalake(dataframe_list, config=config)
        response = agent.chat(query)
        return response

    # -----------------------------------------------------------------------
    def snowflake_connection(self):
        """
        Create a Snowpark session from the ``snowflake_*`` environment variables.

        :return: snowflake.snowpark.Session
        :raises Exception: whatever session creation raised (printed, then re-raised).
        """
        conn = {
            "user": os.environ.get("snowflake_user"),
            "password": os.environ.get("snowflake_password"),
            "account": os.environ.get("snowflake_account"),
            "role": os.environ.get("snowflake_role"),
            "database": os.environ.get("snowflake_database"),
            "warehouse": os.environ.get("snowflake_warehouse"),
            "schema": os.environ.get("snowflake_schema"),
        }
        try:
            return Session.builder.configs(conn).create()
        except Exception as e:
            print(f"Error creating Snowflake session: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_snowflake_table(self, session, table_name, brand):
        """
        Run the configured query for ``table_name`` / ``brand`` on Snowflake.

        :param session: Snowpark session, e.g. from ``snowflake_connection``.
        :param table_name: key into the table config (see ``_get_query``).
        :param brand: brand value substituted into the query template.
        :return: pandas DataFrame with lower-cased column names, or None if the
            query failed (the error is printed, not raised).
        :raises KeyError: if ``table_name`` is not in the config.
        """
        query = self._get_query(table_name, brand)
        # Connect to Snowflake
        try:
            dataframe = session.sql(query).to_pandas()
        except Exception as e:
            # Deliberately best-effort: callers receive None instead of an exception.
            print(f"Error in reading table: {e}")
            return None

        dataframe.columns = dataframe.columns.str.lower()
        print("reading content table successfully")
        return dataframe

    # ----------------------------------------------------------------------------------------------------
    def _get_query(self, table_name: str, brand: str) -> str:
        """Build the SQL text for ``table_name`` with ``brand`` substituted.

        ``self.config[table_name]["query"]`` is a ``str.format`` template with a
        ``{brand}`` placeholder; the brand is lower-cased before insertion.

        :raises KeyError: if ``table_name`` or its ``"query"`` entry is missing.
        """
        # Retrieve the base query template for the given table name
        base_query = self.config[table_name]["query"]
        # Insert the brand condition into the query.
        # NOTE(review): brand is interpolated directly into SQL text; if it can come
        # from untrusted input this is an injection risk -- validate it against an
        # allow-list or switch to bound parameters.
        return base_query.format(brand=brand.lower())

    # ----------------------------------------------------------------------------------------------------
    def mysql_connection(self):
        """Create a SQLAlchemy engine (PyMySQL driver) from the ``mysql_*`` env vars.

        :return: sqlalchemy Engine
        :raises Exception: whatever ``create_engine`` raised (printed, then re-raised).
        """
        # Setting up the MySQL connection parameters
        user = os.environ.get("mysql_user")
        password = os.environ.get("mysql_password")
        host = os.environ.get("mysql_source")
        database = os.environ.get("mysql_schema")

        # Percent-encode the credentials: characters such as '@', ':' or '/' in a
        # password would otherwise be parsed as URL delimiters and corrupt the URL.
        safe_user = quote(user or "", safe="")
        safe_password = quote(password or "", safe="")
        url = f"mysql+pymysql://{safe_user}:{safe_password}@{host}/{database}"

        try:
            return create_engine(url)
        except Exception as e:
            print(f"Error creating MySQL engine: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_mysql_table(self, engine, table_name, brand):
        """Run the configured query for ``table_name`` / ``brand`` on MySQL.

        :param engine: SQLAlchemy engine, e.g. from ``mysql_connection``.
        :param table_name: key into the table config (see ``_get_query``).
        :param brand: brand value substituted into the query template.
        :return: pandas DataFrame with lower-cased column names.
        """
        query = self._get_query(table_name, brand)
        with engine.connect() as conn:
            dataframe = pd.read_sql_query(query, conn)
        # Convert all column names to lowercase
        dataframe.columns = dataframe.columns.str.lower()
        return dataframe


# ----------------------------------------------------------------------------------------------------
class OutputParser(ResponseParser):
    """PandasAI response parser that returns the result unchanged."""

    def __init__(self, context) -> None:
        super().__init__(context)

    def parse(self, result):
        """Return ``result`` as-is, with no formatting or conversion."""
        return result


# ----------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Example queries:
    # query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"
    # query = "select the comments that was on 'pack-bundle-lesson' content_type and have more than 10 likes"
    # query2 = "what is the number of likes, content_title and content_description for the content that received the most comments? "
    # query = "how many users do we have with 0 experience level?"
    query = "select song content_type that have difficulty range of 0-3?"
    #
    # dataframe_path = "data/recent_comment_test.csv"
    # dataframe1 = pd.read_csv(dataframe_path)
    #
    sq = SmartQuery()
    session = sq.snowflake_connection()
    try:
        dataframe = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")
        # read_snowflake_table returns None on failure (it prints the error).
        if dataframe is None:
            raise SystemExit("Could not read the 'contents' table from Snowflake")
        result = sq.perform_query_on_dataframes(query, dataframe)
        print(result)
    finally:
        session.close()