# -*- coding: utf-8 -*-
"""
Created on Thu Apr 25 18:00:03 2024
@author: MK529XT
"""
import streamlit as st
import string
import random
from langchain_helper import get_few_shot_db_chain
import plotly.figure_factory as ff
import numpy as np
# Page-wide configuration; kept as the first Streamlit command in the script.
st.set_page_config(layout="wide")

# CSS for styling
# NOTE(review): the payloads of the styling calls below are empty in this
# file, so they currently inject no rules -- restore the intended <style>
# content if it was lost.
st.markdown("""
""", unsafe_allow_html=True)
st.markdown("""
""", unsafe_allow_html=True)

# Title section
# The literal used to be split across several lines inside single quotes,
# which is a syntax error. NOTE(review): the <h1> wrapper is an assumption
# based on the "Title section" intent -- swap in the original markup if known.
st.markdown("<h1>E-Commerce Analysis</h1>", unsafe_allow_html=True)

# Custom CSS for padding adjustments
st.markdown("""
""", unsafe_allow_html=True)
def random_string() -> dict:
    """Answer the question currently held in ``st.session_state["chat_input"]``.

    The question is forwarded to ``get_few_shot_db_chain`` and its result is
    returned unchanged. If that call raises, a fallback dict is returned
    instead, with the error text in ``"response"``, the question in
    ``"input"``, and ``None`` for ``"result_df"``, ``"sql_command"`` and
    ``"graph_data"``.

    Despite its name, nothing random is generated here.
    """
    try:
        return get_few_shot_db_chain(st.session_state["chat_input"])
    except Exception as e:
        # Report the failure as a normal assistant payload so the chat can
        # display it instead of the callback crashing.
        failure_payload = {
            "result_df" : None,
            "sql_command" : None,
            "response" : f"LLM ran into issues : {str(e)}",
            "input" : st.session_state["chat_input"],
            "graph_data" : None
        }
        return failure_payload
def chat_actions():
    """Record the submitted question and the assistant's reply.

    Appends two entries to ``st.session_state["chat_history"]``: first the
    user's text taken from ``st.session_state["chat_input"]``, then the
    value returned by ``random_string()`` under the ``"assistant"`` role.
    """
    history = st.session_state["chat_history"]

    user_turn = {
        "role": "user",
        "content": st.session_state["chat_input"],
    }
    history.append(user_turn)

    # The reply is computed only after the user turn has been stored.
    assistant_turn = {
        "role": "assistant",
        "content": random_string(),
    }
    history.append(assistant_turn)
# Create the transcript list the first time the session runs the script.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Example Section
st.markdown("""
###### Try some example questions:
- Show me the top 10 products by sales.
- What are the average sales by region?
- Give me the monthly sales trend for the last year.
""")

# Static greeting shown above the transcript.
with st.chat_message("assistant"):
    st.write("Hello 👋 How can I help you today?")

# Submitting the box invokes chat_actions; the typed text is available under
# the "chat_input" session-state key.
st.chat_input("Enter your question", on_submit=chat_actions, key="chat_input")

# Render every stored turn of the conversation.
for message in st.session_state["chat_history"]:
    with st.chat_message(name=message["role"]):
        content = message["content"]

        if isinstance(content, str):
            # Plain-text turn (this is how user questions are stored).
            st.write(content)
        elif isinstance(content, dict):
            # Structured assistant payload. Keys are read with .get so that a
            # payload missing an optional entry does not raise KeyError and
            # break rendering of the whole transcript.
            # The payload's "sql_command" entry is currently not displayed.
            st.write(content.get("response", ""))

            figure = content.get("graph_data")
            result_df = content.get("result_df")

            if figure is not None:
                # A ready-made figure takes precedence over the raw table.
                st.plotly_chart(figure, use_container_width=True)
            elif (
                result_df is not None
                and result_df.shape[0] > 1
                and result_df.shape[1] > 1
            ):
                # Only results with more than one row AND more than one
                # column are shown as a table.
                st.plotly_chart(
                    ff.create_table(result_df), use_container_width=True
                )