Musolyze / SmartQuery_GC.py
Danialebrat's picture
adding codes and files
fe8a467
# this class will use env variables to read secrets from google cloud
from pandasai import Agent
from pandasai import SmartDataframe, SmartDatalake
from pandasai.responses.response_parser import ResponseParser
from pandasai.llm.openai import OpenAI
from pandasai.responses.streamlit_response import StreamlitResponse
import pymysql
from pandasai.connectors import PandasConnector
from snowflake.snowpark import Session
import json
import pandas as pd
from sqlalchemy import create_engine
import os
import streamlit as st
from dotenv import load_dotenv
load_dotenv()
import datetime
# -----------------------------------------------------------------------
# PandasAI API key, as found in the environment (possibly populated by
# load_dotenv() above). It is None when the variable is not set.
key = os.environ.get("PANDASAI_API_KEY")
if key is not None:
    # os.environ only accepts str values; assigning None raises TypeError,
    # so the key is written back only when it actually exists.
    os.environ["PANDASAI_API_KEY"] = key

# Module-level LLM handle used when building the PandasAI config; the token
# is read from the OPENAI_API environment variable.
openai_llm = OpenAI(
    api_token=os.environ.get("OPENAI_API")
)
# -----------------------------------------------------------------------
class SmartQuery:
    """
    Class for interacting with dataframes using natural language.

    It loads per-table SQL templates from ``table_config.json``, reads tables
    from Snowflake or MySQL into pandas DataFrames, and forwards
    natural-language queries about those DataFrames to PandasAI.
    """

    def __init__(self):
        # The config maps a table name to a dict holding at least a "query"
        # entry (that is the only key read, in _get_query).
        with open("table_config.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def perform_query_on_dataframes(self, query, *dataframes, response_format=None):
        """
        Perform a user-defined query on the given pandas DataFrames using PandasAI.

        :param query: the user's query or instruction (str).
        :param dataframes: one or more pandas DataFrames.
        :param response_format: currently unused; kept so existing callers
            that pass it keep working.
        :return: the result of the query executed by PandasAI.
        :raises ValueError: if no dataframe is supplied.
        """
        dataframe_list = list(dataframes)
        if not dataframe_list:
            # Previously an empty list was silently handed to SmartDatalake.
            raise ValueError("perform_query_on_dataframes requires at least one dataframe")

        config = {"llm": openai_llm, "verbose": True, "security": "none", "response_parser": OutputParser}

        if len(dataframe_list) == 1:
            return self.query_single_dataframe(query, dataframe_list[0], config)
        return self.query_multiple_dataframes(query, dataframe_list, config)

    def query_single_dataframe(self, query, dataframe, config):
        """
        Run ``query`` against a single dataframe through a PandasAI ``Agent``.

        :return: whatever ``Agent.chat`` returns.
        """
        agent = Agent(dataframe, config=config)
        return agent.chat(query)

    def query_multiple_dataframes(self, query, dataframe_list, config):
        """
        Run ``query`` against several dataframes through a ``SmartDatalake``.

        :return: whatever ``SmartDatalake.chat`` returns.
        """
        agent = SmartDatalake(dataframe_list, config=config)
        return agent.chat(query)

    # -----------------------------------------------------------------------
    def snowflake_connection(self):
        """
        Create a Snowflake (Snowpark) session from ``snowflake_*`` environment variables.

        :return: the created session.
        :raises Exception: whatever the session builder raises; the error is
            printed and then re-raised.
        """
        conn = {
            "user": os.environ.get("snowflake_user"),
            "password": os.environ.get("snowflake_password"),
            "account": os.environ.get("snowflake_account"),
            "role": os.environ.get("snowflake_role"),
            "database": os.environ.get("snowflake_database"),
            "warehouse": os.environ.get("snowflake_warehouse"),
            "schema": os.environ.get("snowflake_schema"),
        }
        try:
            return Session.builder.configs(conn).create()
        except Exception as e:
            print(f"Error creating Snowflake session: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_snowflake_table(self, session, table_name, brand):
        """
        Read a table from Snowflake into a pandas DataFrame.

        :param session: object exposing ``sql(query).to_pandas()`` (as returned
            by ``snowflake_connection``).
        :param table_name: key of the table entry in the loaded config.
        :param brand: brand name substituted into the table's query template.
        :return: the resulting dataframe, with column names lower-cased.
        :raises Exception: any error raised while running the query; it is
            printed and then re-raised, matching the connection helpers,
            instead of silently returning ``None``.
        """
        query = self._get_query(table_name, brand)
        try:
            dataframe = session.sql(query).to_pandas()
            dataframe.columns = dataframe.columns.str.lower()
            print("reading content table successfully")
            return dataframe
        except Exception as e:
            print(f"Error in reading table: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def _get_query(self, table_name: str, brand: str) -> str:
        """
        Build the SQL text for ``table_name`` with ``brand`` filled in.

        :raises KeyError: if ``table_name`` (or its "query" entry) is missing
            from the config.
        """
        # Retrieve the base query template for the given table name
        base_query = self.config[table_name]["query"]
        # NOTE(review): the brand is interpolated straight into the SQL text.
        # If it can ever come from untrusted input, validate it against a
        # known list of brands or switch to bound parameters.
        return base_query.format(brand=brand.lower())

    # ----------------------------------------------------------------------------------------------------
    def mysql_connection(self):
        """
        Create a SQLAlchemy engine for MySQL from ``mysql_*`` environment variables.

        :return: the created engine.
        :raises Exception: whatever ``create_engine`` raises; the error is
            printed and then re-raised.
        """
        # Setting up the MySQL connection parameters
        user = os.environ.get("mysql_user")
        password = os.environ.get("mysql_password")
        host = os.environ.get("mysql_source")
        database = os.environ.get("mysql_schema")
        try:
            # NOTE(review): credentials are placed in the URL as-is, so a
            # password containing characters such as '@' or '/' must already
            # be URL-encoded in the environment - confirm for each deployment.
            return create_engine(f"mysql+pymysql://{user}:{password}@{host}/{database}")
        except Exception as e:
            print(f"Error creating MySQL engine: {e}")
            raise

    # ----------------------------------------------------------------------------------------------------
    def read_mysql_table(self, engine, table_name, brand):
        """
        Read a table from MySQL into a pandas DataFrame.

        :param engine: engine exposing ``connect()`` (as returned by
            ``mysql_connection``).
        :param table_name: key of the table entry in the loaded config.
        :param brand: brand name substituted into the table's query template.
        :return: the resulting dataframe, with column names lower-cased.
        """
        query = self._get_query(table_name, brand)
        with engine.connect() as conn:
            dataframe = pd.read_sql_query(query, conn)
        # Convert all column names to lowercase if not
        dataframe.columns = dataframe.columns.str.lower()
        return dataframe
# ----------------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------------
class OutputParser(ResponseParser):
    """
    Response parser that hands the result back without modifying it.
    """

    def __init__(self, context) -> None:
        # Nothing extra to set up; defer entirely to the base parser.
        super().__init__(context)

    def parse(self, result):
        """
        Return ``result`` exactly as received.
        """
        return result
# ----------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Manual smoke test: load the "contents" table for one brand from
    # Snowflake and ask a natural-language question about it.
    #
    # Other example prompts:
    # query_multi = "get top 5 contents that had the most interactions and their 'content_type' is 'song'. Also include the number of interaction for these contents"
    # query = "select the comments that was on 'pack-bundle-lesson' content_type and have more than 10 likes"
    # query2 = "what is the number of likes, content_title and content_description for the content that received the most comments? "
    # query = "how many users do we have with 0 experience level?"
    query = "select song content_type that have difficulty range of 0-3?"

    # A local CSV can be used instead of Snowflake, e.g.:
    # dataframe1 = pd.read_csv("data/recent_comment_test.csv")

    sq = SmartQuery()
    session = sq.snowflake_connection()
    try:
        dataframe = sq.read_snowflake_table(session, table_name="contents", brand="drumeo")
    finally:
        # The table is already materialised as a pandas DataFrame, so the
        # session can be released whether or not the read succeeded.
        session.close()

    if dataframe is None:
        # Guard against a failed read that reported its error without raising.
        raise SystemExit("Could not load the 'contents' table; see the error printed above.")

    result = sq.perform_query_on_dataframes(query, dataframe)
    print(result)