visualquery / app.py
binaychandra's picture
Modified openai keys and files
47d3d15
raw
history blame
2.66 kB
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 25 18:00:03 2024
@author: MK529XT
"""
import streamlit as st
import string
import random
from langchain_helper import get_few_shot_db_chain
import plotly.figure_factory as ff
import numpy as np
# Wide layout is currently disabled; uncomment to enable it.
# st.set_page_config(layout="wide")

# Inject the CSS used by the page title. The markup is rendered as raw HTML,
# hence unsafe_allow_html=True. The disabled border rule uses a real CSS
# comment (/* ... */) because '#' is not a comment marker in CSS.
st.markdown("""
<style>
.title {
text-align: center;
outline: solid yellow;
font-size: 20px;
font-family: Arial, Helvetica, sans-serif;
color: #FFFFFF;
padding-top: 5px;
padding-bottom: 5px;
/* border-bottom: 2px solid #FFFF00; */
background-color: #050201;
}
</style>
""", unsafe_allow_html=True)

# Title section, styled by the ".title" class defined above.
st.markdown("<h1 class='title'>E-Commerce Analysis</h1>", unsafe_allow_html=True)

# Static greeting shown as an assistant message at the top of the chat.
# The waving-hand emoji was previously stored as mis-decoded UTF-8 bytes.
with st.chat_message("assistant"):
    st.write("Hello 👋 How can I help you today?")
def random_string() -> dict:
    """Answer the question currently held in ``st.session_state["chat_input"]``.

    The question is passed to ``get_few_shot_db_chain`` and its result is
    returned unchanged (presumably a dict with the same keys as the fallback
    below -- confirm against ``langchain_helper``).

    Returns:
        The chain's response, or, if the chain raises, a fallback dict whose
        ``"response"`` carries the error text and whose data fields
        (``"result_df"``, ``"sql_command"``, ``"graph_data"``) are ``None``.
    """
    try:
        return get_few_shot_db_chain(st.session_state["chat_input"])
    except Exception as err:
        # Any chain failure is reported to the user as the assistant's reply
        # instead of propagating out of the chat callback.
        return {
            "result_df": None,
            "sql_command": None,
            "response": f"LLM ran into issues : {err!s}",
            "input": st.session_state["chat_input"],
            "graph_data": None,
        }
def chat_actions():
    """Record one chat turn in ``st.session_state["chat_history"]``.

    Registered as the ``on_submit`` callback of the chat input. Appends the
    user's question first, then the assistant entry whose content is the
    result of ``random_string()``.
    """
    history = st.session_state["chat_history"]

    # The user entry is stored before the chain is called, so the question is
    # already in the history while the answer is being produced.
    user_turn = {"role": "user", "content": st.session_state["chat_input"]}
    history.append(user_turn)

    assistant_turn = {"role": "assistant", "content": random_string()}
    history.append(assistant_turn)
# Create the history list once per session; later reruns reuse it.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# The submitted text is stored under the "chat_input" key and chat_actions
# is invoked on submit to extend the history.
st.chat_input("Enter your question", on_submit=chat_actions, key="chat_input")

# Replay the whole conversation on every rerun.
for message in st.session_state["chat_history"]:
    with st.chat_message(name=message["role"]):
        content = message["content"]
        if isinstance(content, str):
            # User turns are stored as plain text.
            st.write(content)
        elif isinstance(content, dict):
            # Assistant turns are the response dict built by random_string().
            # Its "sql_command" entry is intentionally not displayed.
            st.write(content["response"])

            # Both fields are optional (None in the error fallback); .get
            # also tolerates a response that omits them entirely.
            graph_data = content.get("graph_data")
            result_df = content.get("result_df")

            if graph_data is not None:
                st.plotly_chart(graph_data, use_container_width=True)
            elif (
                result_df is not None
                and result_df.shape[0] > 1
                and result_df.shape[1] > 1
            ):
                # Only results with more than one row AND more than one
                # column are rendered as a table.
                st.plotly_chart(ff.create_table(result_df), use_container_width=True)