# Page banner captured together with the code: "Spaces" status was "Sleeping".
| import gradio as gr | |
| from sqlalchemy import create_engine | |
| import pandas as pd | |
| import openai | |
| import os | |
| from lida import Manager, TextGenerationConfig, llm | |
| from llmx.generators.text.openai_textgen import OpenAITextGenerator | |
| from langchain_openai import AzureChatOpenAI | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| import pandas as pd | |
| import base64 | |
| import numpy as np | |
| import matplotlib.image as mpimg | |
| from PIL import Image | |
| from langchain_core.messages import HumanMessage | |
| from langchain_openai import ChatOpenAI | |
| import base64 | |
| from utils.azure_blob import AzureBlob | |
| from langchain.output_parsers import CommaSeparatedListOutputParser | |
| from pprint import pprint | |
# --- Runtime configuration, read from environment variables at import time ---

# Blob-storage client built from the ``azure_blob_conn`` setting; ``ab`` is a
# short alias bound to the very same object.
ab = azure_blob = AzureBlob(os.getenv("azure_blob_conn"))

# Publish the Azure OpenAI settings under the AZURE_OPENAI_* variable names.
# The API key is copied from OPENAI_API_KEY and the API version is pinned here.
os.environ.update(
    {
        "AZURE_OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
        "AZURE_OPENAI_API_VERSION": "2023-06-01-preview",
        "AZURE_OPENAI_ENDPOINT": os.getenv("AZURE_OPENAI_ENDPOINT"),
    }
)

# Database connection settings; each one is None when its variable is unset.
db_host, db_name, db_user, db_password = (
    os.getenv(key) for key in ('DB_HOST', 'DB_NAME', 'DB_USER', 'DB_PASSWORD')
)

# Module-level chat model bound to the "CapSuiteGPT4omini" deployment.
model = AzureChatOpenAI(
    openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    deployment_name="CapSuiteGPT4omini",
)
def choose_table(question):
    """Load the sample client's sales data and reshape it for analysis.

    The latest ``sale`` and ``membership`` parquet files of the client
    ``foodBeverageSample1`` are located through the module-level blob client
    ``ab``, left-joined on ``member_id`` and reduced to a fixed list of
    analysis columns.

    Args:
        question: The user's question. It is not used by the current
            implementation; the parameter is kept so existing callers keep
            working.

    Returns:
        pandas.DataFrame: the joined data restricted to the columns
        ``trxn_item_id, trxn_id, sales_amount, unit_price, sales_qty,
        trxn_item_discount_amt, trxn_date, trxn_channel, staff_name,
        customer_name, prod_category, prod_type, prod_name, capsuite_ref,
        gender, age, trxn_month``. ``trxn_date`` and ``trxn_month`` are
        returned as strings.

    Raises:
        Exception: any error raised while loading or transforming the data
            is logged and then re-raised unchanged. (A ``return`` inside
            ``finally`` would swallow it and hand back either an unbound
            name or a half-transformed frame.)
    """
    str_client_name = 'foodBeverageSample1'
    try:
        df_sales = pd.read_parquet(
            ab.get_latest_parquet('landing', str_client_name, 'sale', 'sol_')
        )
        df_members = pd.read_parquet(
            ab.get_latest_parquet('landing', str_client_name, 'membership', 'mem_')
        )

        # Keep every sale line; membership attributes are attached when the
        # member is known. Columns present in both frames get a suffix.
        df_data = pd.merge(
            df_sales,
            df_members,
            on='member_id',
            how='left',
            suffixes=('_sale_order_line', '_membership'),
        )

        # Computed before the rename below, hence the original column names.
        df_data['sales_amount'] = (
            df_data['trxn_item_target_curr_unit_price'].astype(float)
            * df_data['trxn_item_qty'].astype(float)
        )

        df_data = df_data.rename(
            columns={
                'trxn_item_target_curr_unit_price': 'unit_price',
                'display_name_membership': 'customer_name',
                'capsuite_ref_sale_order_line': 'capsuite_ref',
                'trxn_item_qty': 'sales_qty',
            }
        )

        # Reduce the timestamp to a calendar date, derive the month from that
        # date, and store both as plain strings.
        trxn_dates = pd.to_datetime(df_data['trxn_date']).dt.date
        df_data['trxn_month'] = pd.to_datetime(trxn_dates).dt.to_period('M').astype(str)
        df_data['trxn_date'] = trxn_dates.astype(str)

        return df_data[
            [
                'trxn_item_id', 'trxn_id', 'sales_amount', 'unit_price',
                'sales_qty', 'trxn_item_discount_amt', 'trxn_date',
                'trxn_channel', 'staff_name', 'customer_name',
                'prod_category', 'prod_type', 'prod_name', 'capsuite_ref',
                'gender', 'age', 'trxn_month',
            ]
        ]
    except Exception as e:
        print(f"Error while: {e}")
        raise
def encode_image(image_path):
    """Return the contents of the file at *image_path* as Base64 text.

    The file is read in binary mode; the Base64 bytes are decoded as UTF-8 so
    the caller receives a ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def random_response(message):
    """Answer a sales question with a generated chart and written insights.

    Steps, all executed inside one retry attempt:

    1. Load the prepared sales frame through ``choose_table``.
    2. Summarise it with LIDA and request one goal (the goal is only logged).
    3. Ask the chat model to restate the question as a single goal, guided
       by the data summary.
    4. Let LIDA render a matplotlib chart for that goal and save it to
       ``chart_1.png``.
    5. Send the saved chart to the chat model and ask for business insights.

    Args:
        message: The user's question as plain text.

    Returns:
        tuple: ``(text, image)``. On success ``text`` is the model's answer
        and ``image`` is the chart as loaded by ``matplotlib.image.imread``.
        When every attempt fails the result is
        ``("An error occurred after multiple attempts.", None)``.
    """
    max_attempts = 1  # Set the maximum number of attempts
    image_path = "chart_1.png"
    question = message

    for attempt in range(1, max_attempts + 1):
        try:
            df_data = choose_table(message)

            # Replace missing values with '' and make object columns str.
            # NOTE(review): filling a numeric column that contains NaN with ''
            # turns it into an object column, so the loop below converts the
            # whole column to text - confirm this is intended (e.g. for age).
            df_data.fillna('', inplace=True)
            for col in df_data.columns:
                if df_data[col].dtype == 'object':
                    df_data[col] = df_data[col].astype(str)

            text_gen = OpenAITextGenerator(
                provider='openai',
                api_type='azure',
                azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
                api_key=os.getenv('OPENAI_API_KEY'),
                api_version='2023-05-15',
            )
            lida = Manager(text_gen=text_gen)
            text_gen_config = TextGenerationConfig(
                n=1,
                model='CapSuiteGPT35T16K',
                temperature=0,
            )

            summary = lida.summarize(df_data)
            print(f'*'*50)
            pprint(f"{summary}")
            print(f'*'*50)
            time_now = pd.Timestamp.now()
            print(f'Datetime now:{time_now}')

            # The LIDA goal is logged for diagnostics only; the goal used for
            # the chart comes from the chat-model prompt below.
            goals = lida.goals(
                summary,
                n=1,
                textgen_config=text_gen_config,
                persona=f'An data analyst of the company who want to know {question}',
            )
            print(f'goals: {goals[0]}')

            # A dedicated name avoids shadowing the module-level ``model`` and
            # the ``llm`` helper imported from lida.
            chat_model = AzureChatOpenAI(
                deployment_name="CapSuiteGPT4omini",
                openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
                temperature=0,
            )
            prompt = ChatPromptTemplate.from_template(
                "Based on the data below:{str_summary},"
                "The data is sales order line is every transaction of the company."
                "Base on the question:{question}, regenerate the output"
                "If the original question is not metion time related varibles,do not add it."
                "For example: 'Goal(question='What are the sales trends by product category?visualization='bar chart of prod_category against sum(trxn_item_qty) grouped by trxn_date'. and"
                "'Goal(question='Who are the top customers based on transaction count?', visualization='Bar chart of customer_name vs. count(trxn_id)')"
                "If top in your output Goal question, default it to 10."
                "The visualization should align with the question and the data."
                "Usually, when deal with:age, show all the data."
                "But for other datas beside age: customer,prouct,sales,qty,etc show top 10."
                "Process the top data at last when put in the graph."
                "When ask customers,customer, it means customer_name."
                "When ask product, it means prod_name."
                "When ask category, it means prod_category."
                "etc, find the right column name exsiting in the data."
                "If the data columns is empty, please ignore the column."
                "Only output 1 question."
            )
            # The prompt receives the input mapping directly. Wrapping each
            # key in RunnablePassthrough() would forward the *whole* input
            # dict to both placeholders instead of the matching value.
            chain = prompt | chat_model | CommaSeparatedListOutputParser()
            insights = chain.invoke({"str_summary": str(summary), "question": question})
            print(f'*'*50)
            print(f'insights: {insights}')

            # Supported libraries: 'matplotlib', 'seaborn', 'plotly', 'bokeh',
            # 'ggplot', 'altair'.
            charts = lida.visualize(
                summary=summary,
                goal=str(insights)+"Graph heigh 800,width 1000.Set different colors to different varibles.x label rotate 60 degree,do not use the guide line",
                textgen_config=text_gen_config,
                library='matplotlib',
            )
            # Fail with a clear message instead of an IndexError or an
            # unbound-variable error further down.
            if not charts:
                raise RuntimeError("LIDA returned no chart for the requested goal")
            chart = charts[0]
            print(f'*'*50)
            print(f"{chart.code}")

            # NOTE(review): the fixed file name is shared by all requests;
            # concurrent users could overwrite each other's chart.
            chart.savefig(image_path)
            print(f'*'*50)
            print(f"Chart saved")

            img = mpimg.imread(image_path)
            print(f'*'*50)
            print(f"Image opened")

            base64_image = encode_image(image_path)
            response = chat_model.invoke(
                [
                    HumanMessage(
                        content=[
                            {"type": "text", "text": f"Give me some business insights base on the graph, contain exact number conclusion."},
                            {
                                "type": "image_url",
                                "image_url": {
                                    # The chart is written to a .png file, so
                                    # declare PNG rather than JPEG.
                                    "url": f"data:image/png;base64,{base64_image}"
                                },
                            },
                        ]
                    )
                ]
            )
            return response.content, img
        except Exception as e:
            print(f"Attempt {attempt} failed with error: {e}")

    # Every attempt failed: always hand back the two values the UI expects.
    return "An error occurred after multiple attempts.", None
# Gradio UI: one row holding a chart preview column and a chat column.
# NOTE(review): the component nesting below is reconstructed; confirm that the
# response box and the button belong to the second column.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            temp_img = gr.Image(height=800)
        with gr.Column():
            chat_input = gr.Textbox(label="Chat", placeholder="Type your message here...")
            # Clicking an example copies it into the chat box.
            examples = gr.Examples(
                inputs=chat_input,
                examples=[
                    'Top 10 prod_cate sales',
                    'Top product in category Seafood',
                    'Total sales amount by product category each day using line chart',
                    'What are the top selling at product level??',
                    'Sales amount distribution by age',
                    'Sales amount distribution by gender',
                    'Top customer by sales amount',
                ],
            )
            chat_output = gr.Textbox(label="Response", interactive=False)
            # Button label is Chinese for "Generate response".
            submit_button = gr.Button("生成响应")
    # The handler returns (text, image), matching the two outputs listed here.
    submit_button.click(
        fn=random_response,
        inputs=chat_input,
        outputs=[chat_output, temp_img],
    )

demo.launch()