# Last commit by ageraustine: 902afe3 "cache agent" (verified)
from langchain_core.messages import AIMessage, HumanMessage
import streamlit as st
from streamlit_chat import message
import os
from agent import ReActAgent
@st.cache_resource
def get_agent():
    """Construct the ReActAgent used to answer chat messages.

    Decorated with ``st.cache_resource``, so Streamlit memoizes the returned
    object and reruns of this script reuse it instead of building a new agent.
    """
    react_agent = ReActAgent()
    return react_agent


# Agent instance that answers user messages in this app.
agent = get_agent()
# --- Streamlit app layout ---
st.title(body="Bliff AI")
# Sidebar control; evaluates to True on the rerun triggered by clicking it.
clear_button = st.sidebar.button(label="Clear Conversation", key="clear")
# Opening assistant message shown before the user has typed anything.
_GREETING = 'Quick and Precise Search about just anything'


def _fresh_chat_state():
    """Return a new, independent set of default chat-state values.

    Keys:
        generated: assistant messages, starting with the greeting.
        past: user messages, in the order they were sent.
        history: HumanMessage/AIMessage objects passed to the agent as context.

    A new dict (and new lists) is built on every call so that init and reset
    never share mutable list objects.
    """
    return {
        'generated': [_GREETING],
        'past': [],
        'history': [],
    }


# Initialise any session state keys that are missing (first run of a session).
for _key, _value in _fresh_chat_state().items():
    if _key not in st.session_state:
        st.session_state[_key] = _value

# Reset the whole conversation when the sidebar "Clear Conversation" button
# was clicked; uses the same defaults as initialisation so the two can't drift.
if clear_button:
    for _key, _value in _fresh_chat_state().items():
        st.session_state[_key] = _value
def _history_as_text(messages):
    """Render chat history as newline-separated 'Human: ...' / 'AI: ...' lines.

    Args:
        messages: sequence of chat messages; only HumanMessage and AIMessage
            instances are included, anything else is skipped.

    Returns:
        The rendered history. Message contents are kept verbatim (no
        trailing-newline trimming of the last message), and there is no
        trailing separator after the final line.
    """
    lines = []
    for chat in messages:
        if isinstance(chat, HumanMessage):
            lines.append(f"Human: {chat.content}")
        elif isinstance(chat, AIMessage):
            lines.append(f"AI: {chat.content}")
    # join avoids quadratic += and never leaves a trailing "\n" to strip.
    return "\n".join(lines)


# Container for the chat history. It is created before the input container so
# the conversation is rendered above the text box.
response_container = st.container()
# Container for the input text box.
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    # Skip empty or whitespace-only submissions rather than querying the agent.
    if submit_button and user_input.strip():
        if st.session_state['history']:
            history_string = _history_as_text(st.session_state['history'])
            response = agent.run(user_input, history_string)['output']
        else:
            # First turn: no prior context to pass.
            response = agent.run(user_input)['output']
        st.session_state['generated'].append(response)
        st.session_state['past'].append(user_input)
        st.session_state['history'].extend(
            [HumanMessage(content=user_input), AIMessage(content=response)]
        )
# Render the conversation into the history container: assistant message i,
# then user message i when one exists. Because generated[0] is the greeting,
# each user message is followed by the assistant reply appended with it.
if st.session_state['generated']:
    with response_container:
        past_messages = st.session_state["past"]
        for i, bot_text in enumerate(st.session_state["generated"]):
            message(bot_text, key=str(i))
            if i < len(past_messages):
                message(past_messages[i], is_user=True, key=str(i) + '_user')