File size: 3,388 Bytes
adf2969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5da7f5
edcfd7e
adf2969
 
 
 
 
 
 
 
 
5e99252
52acadd
adf2969
 
 
 
 
 
edcfd7e
 
 
 
 
 
 
 
 
 
 
 
2ce7c79
 
edcfd7e
2ce7c79
c2d7c5b
 
2ce7c79
 
 
 
edcfd7e
c2d7c5b
 
2ce7c79
edcfd7e
 
2ce7c79
adf2969
47d3d15
6b99c67
 
adf2969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13590ec
a40b058
 
 
 
 
 
13590ec
 
 
 
a5da7f5
adf2969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 25 18:00:03 2024

@author: MK529XT
"""

import logging
import random
import string

import numpy as np
import plotly.figure_factory as ff
import streamlit as st

from langchain_helper import get_few_shot_db_chain

st.set_page_config(layout="wide")

# Title banner: centred white Arial text on a near-black background with a
# yellow outline; the negative top margin pulls it up toward the top of the page.
st.markdown("""
<style>
        .title {
            text-align: center;
            outline: solid yellow;
            font-size: 20px;
            font-family: Arial, Helvetica, sans-serif;
            color: #FFFFFF;
            margin-top: -80px;
            padding-top: 5px;
            padding-bottom: 5px;
            /* border-bottom: 2px solid #FFFF00; */
            background-color: #050201;
        }
</style>
""", unsafe_allow_html=True)

# NOTE(review): these id selectors only take effect if the rendered page has
# elements with id="header" / id="main-content"; nothing in this file creates
# them -- confirm they still match the Streamlit version in use.
st.markdown("""
<style>
    #header {
        visibility: hidden;
    }
    #main-content {
        padding-top: 0;
    }
</style>
""", unsafe_allow_html=True)

# Horizontal insets for the chat input and for every element container, plus
# extra spacing below each element. Plain string rather than an f-string:
# nothing is interpolated, so the CSS braces need no doubling.
st.markdown("""
<style>
    .chat-input-container {
        margin-left: 80px;
        margin-right: 80px;
        /* padding: 10px; */
        border-radius: 5px;
    }
    .element-container {
        padding-bottom: 20px;
        margin-left: 80px;
        margin-right: 80px;
    }
</style>
""", unsafe_allow_html=True)

# Title section (styled by the .title rule above)
st.markdown("<h1 class='title'>E-Commerce Analysis</h1>", unsafe_allow_html=True)



def random_string() -> dict:
    """Answer the question currently stored in ``st.session_state["chat_input"]``.

    Despite its name, this does not generate a random string. It forwards the
    submitted question to ``get_few_shot_db_chain`` and returns that helper's
    result. The name is kept because ``chat_actions`` calls it.

    Returns:
        The dict returned by ``get_few_shot_db_chain`` (presumably with the
        same keys as the fallback below -- confirm against langchain_helper).
        If that call raises, a fallback dict is returned instead, with keys
        ``result_df``, ``sql_command``, ``response``, ``input`` and
        ``graph_data``. Its ``response`` carries the error message, ``input``
        echoes the question, and the other values are ``None``.
    """
    question = st.session_state["chat_input"]
    try:
        return get_few_shot_db_chain(question)
    except Exception as exc:
        # UI boundary: show the failure in the chat instead of crashing the
        # page, but keep the full traceback in the server log for debugging.
        logging.getLogger(__name__).exception(
            "get_few_shot_db_chain failed for input %r", question
        )
        return {
            "result_df": None,
            "sql_command": None,
            "response": f"LLM ran into issues : {exc}",
            "input": question,
            "graph_data": None,
        }

def chat_actions():
    """Record one question/answer exchange in the session transcript.

    Registered as the ``on_submit`` callback of ``st.chat_input``. It appends
    the user's submitted text, then the assistant reply produced by
    ``random_string()``, to ``st.session_state["chat_history"]``.
    """
    transcript = st.session_state["chat_history"]
    question = st.session_state["chat_input"]

    # The user turn is appended before the reply is computed.
    transcript.append({"role": "user", "content": question})

    reply = random_string()
    transcript.append({"role": "assistant", "content": reply})


# Create the transcript only on the first run of a session, so turns recorded
# by chat_actions survive later reruns of this script.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Example Section
EXAMPLE_QUESTIONS_MD = """
###### Try some example questions:
 - Show me the top 10 products by sales.
 - What are the average sales by region?
 - Give me the monthly sales trend for the last year.
"""
st.markdown(EXAMPLE_QUESTIONS_MD)

# Static greeting shown above the transcript.
greeting = st.chat_message(name="assistant")
with greeting:
    st.write("Hello 👋 How can I help you today?")

# Submitting the box runs chat_actions, which reads the submitted text from
# st.session_state["chat_input"] and appends to chat_history.
st.chat_input(
    "Enter your question",
    on_submit=chat_actions,
    key="chat_input",
)

# Replay the whole transcript. User turns hold the submitted string; assistant
# turns hold the response dict built by random_string().
for message in st.session_state["chat_history"]:
    with st.chat_message(name=message["role"]):
        content = message["content"]

        if isinstance(content, str):
            st.write(content)

        # When this is the LLM / bot response
        elif isinstance(content, dict):
            st.write(content["response"])
            result_df = content["result_df"]
            figure = content["graph_data"]
            if figure is not None:
                # A chart supplied by the chain takes precedence over the table.
                st.plotly_chart(figure, use_container_width=True)
            elif (
                result_df is not None
                and result_df.shape[0] > 1
                and result_df.shape[1] > 1
            ):
                # NOTE(review): results with a single row or a single column
                # are not shown as a table (only the text answer appears) --
                # confirm this is intended.
                st.plotly_chart(ff.create_table(result_df), use_container_width=True)