Tyrex_Chatbot / app.py
pritmanvar-bacancy's picture
initial commit
accb514 verified
import re
import streamlit as st
import random
from app_config import SYSTEM_PROMPT, NLP_MODEL_NAME, NUMBER_OF_VECTORS_FOR_RAG, NLP_MODEL_TEMPERATURE, NLP_MODEL_MAX_TOKENS, VECTOR_MAX_TOKENS, SLOT_ID_PATTERN, INVOICE_NUM_PATTERN
from utils.functions import get_vectorstore, tiktoken_len
from langchain.memory import ConversationSummaryBufferMemory
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.agents.agent_types import AgentType
from dotenv import load_dotenv
from pathlib import Path
import os
from tools.tools import response_generator,cancle_ongoing_process,cancle_slot,get_invoice,get_slot_details,schedule_slot,update_slot,price_estimation
import session_manager
# Load environment variables (e.g. GROQ_API_KEY, read further down) from the
# .env file located in the current working directory.
env_path = Path('.').joinpath('.env')
load_dotenv(dotenv_path=env_path)
# Custom CSS injected into the page. The selectors are auto-generated
# Streamlit class names: the first one is laid out right-to-left and
# right-aligned, the second one is hidden.
# NOTE(review): presumably these target the user chat bubble and an unwanted
# widget element; the class names are tied to the Streamlit version — confirm.
_CHAT_LAYOUT_CSS = """
<style>
.st-emotion-cache-janbn0 {
flex-direction: row-reverse;
text-align: right;
}
.st-emotion-cache-1ec2a3d{
display: none;
}
</style>
"""
st.markdown(_CHAT_LAYOUT_CSS, unsafe_allow_html=True)
# Initialize the per-session state. Each entry is created only when it is not
# already present in st.session_state, so values survive across script reruns.

# Chat history: starts with the system prompt as the only message.
print("SYSTEM MESSAGE")
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

# Groq-hosted chat model shared by the memory and by every agent.
# NOTE(review): str(os.getenv(...)) turns a missing GROQ_API_KEY into the
# literal string "None", so a missing key only surfaces on the first API call.
print("SYSTEM MODEL")
if "llm" not in st.session_state:
    st.session_state.llm = ChatGroq(
        temperature=NLP_MODEL_TEMPERATURE,
        groq_api_key=str(os.getenv('GROQ_API_KEY')),
        model_name=NLP_MODEL_NAME,
    )

# Summarising conversation memory. Its token limit is the model limit minus
# the system prompt length and minus the space reserved for retrieved vectors.
print("rag")
if "rag_memory" not in st.session_state:
    st.session_state.rag_memory = ConversationSummaryBufferMemory(
        llm=st.session_state.llm,
        max_token_limit=(
            NLP_MODEL_MAX_TOKENS
            - tiktoken_len(SYSTEM_PROMPT)
            - VECTOR_MAX_TOKENS * NUMBER_OF_VECTORS_FOR_RAG
        ),
    )

# Retriever over the vector store. The number of documents to return has to
# be passed through search_kwargs; a bare k=... keyword is not a search
# argument and does not limit the results.
# NOTE(review): assumes get_vectorstore() returns a LangChain VectorStore.
print("retrival")
if "retriever" not in st.session_state:
    st.session_state.retriever = get_vectorstore().as_retriever(
        search_kwargs={"k": NUMBER_OF_VECTORS_FOR_RAG}
    )

# Per-agent conversation history, keyed by agent name.
print("agent_history")
if "agent_history" not in st.session_state:
    st.session_state.agent_history = {}

# Name of the agent that will handle the next user message.
print("next agent")
if "next_agent" not in st.session_state:
    st.session_state.next_agent = "general_agent"

# Most recent user query.
print("last_query")
if "last_query" not in st.session_state:
    st.session_state.last_query = ""

# Identifier of the tool used last; read below to decide how to render the
# response. NOTE(review): presumably written by the tools — confirm.
print("last_tool")
if "last_tool" not in st.session_state:
    st.session_state.last_tool = ""

# Hand the session state over to session_manager.
# NOTE(review): presumably so the tools can access it outside this script.
print("agent")
session_manager.set_session_state(st.session_state)
# Initialize all agents once per session: one general agent that has every
# tool, plus one narrow agent per workflow that only has that workflow's tool
# and the tool for cancelling the ongoing process.
if "agents" not in st.session_state:
    # Tools available to each agent, keyed by agent name (insertion order is
    # the order in which the agents are built).
    agent_tools = {
        "general_agent": [
            response_generator,
            schedule_slot,
            cancle_ongoing_process,
            cancle_slot,
            update_slot,
            get_slot_details,
            get_invoice,
            price_estimation,
        ],
        "slot_booking_agent": [schedule_slot, cancle_ongoing_process],
        "slot_canclelation_agent": [cancle_slot, cancle_ongoing_process],
        "slot_update_agent": [update_slot, cancle_ongoing_process],
        "get_invoice_agent": [get_invoice, cancle_ongoing_process],
        "price_estimation_agent": [price_estimation, cancle_ongoing_process],
    }

    # Every agent shares the same LLM and executor settings; only the tool
    # list differs.
    st.session_state.agents = {
        agent_name: initialize_agent(
            tools=tools,
            llm=st.session_state.llm,
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
            max_iterations=3,
            handle_parsing_errors=True,
        )
        for agent_name, tools in agent_tools.items()
    }
# Optional tyre image upload. When a file is provided, its bytes are written
# to "Tyre.png" in the working directory (whatever the uploaded extension),
# overwriting any previous upload.
img_file_buffer = st.file_uploader(
    'Upload a Tyre image',
    type=['png', 'jpg', 'jpeg'],
    accept_multiple_files=False,
    label_visibility="hidden",
)
if img_file_buffer:
    Path("Tyre.png").write_bytes(img_file_buffer.getbuffer())
print("container")
# Replay the stored conversation inside a fixed-height container. System
# messages are never shown; every other message is rendered according to its
# 'type': "table" as a dataframe indexed by its first column, "html" as raw
# HTML, anything else as plain text.
container = st.container(height=600)
for entry in st.session_state.messages:
    if entry["role"] == "system":
        continue
    with container.chat_message(entry["role"]):
        kind = entry['type']
        body = entry['content']
        if kind == "table":
            st.dataframe(body.set_index(body.columns[0]))
        elif kind == "html":
            st.markdown(body, unsafe_allow_html=True)
        else:
            st.write(body)
# When the user submits a message: echo it, run the currently selected agent
# on the conversation so far, then render and store the agent's answer.
if prompt := st.chat_input("Enter your query here... "):
    with container.chat_message("user"):
        st.write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt, "type": "string"})
    st.session_state.last_query = prompt

    with container.chat_message("assistant"):
        agent_name = st.session_state.next_agent

        # A workflow agent (anything but the general agent) is given its own
        # earlier turns as a transcript; the general agent only sees the new
        # message.
        past_turns = []
        if agent_name != "general_agent" and agent_name in st.session_state.agent_history:
            past_turns = st.session_state.agent_history[agent_name]

        # Transcript speaker labels; entries with any other role are skipped.
        speaker_labels = {"user": "user", "assistant": "ai"}
        transcript_lines = [
            f"""{speaker_labels[turn['role']]}: {turn['content']}\n"""
            for turn in past_turns
            if turn['role'] in speaker_labels
        ]
        transcript_lines.append(f"""user: {prompt}\n""")
        current_conversation = "".join(transcript_lines)

        print("***************************************** HISTORY ********************************************")
        print(st.session_state.agent_history)
        print("****************************************** Messages ******************************************")
        print("messages", current_conversation)
        print()
        print()

        response = st.session_state.agents[agent_name](current_conversation)['output']

        print("******************************************************** Response ********************************************************")
        print("MY RESPONSE IS:", response)

        # last_tool is read only after the agent has run, so it reflects
        # whatever the run left in the session state.
        last_tool = st.session_state.last_tool

        if last_tool == "get-invoice-tool":
            # The invoice is rendered (and stored) as raw HTML.
            st.session_state.messages.append({"role": "assistant", "content": response, "type": "html"})
            st.markdown(response, unsafe_allow_html=True)
        elif last_tool == "slot-fetching-tool" and not isinstance(response, str):
            # Structured result with a dataframe under 'df' and a tyre count
            # under 'num_of_tyres'. The isinstance guard matters because the
            # agent can also answer with a plain string (for instance when it
            # stops early, or when last_tool is left over from an earlier
            # turn); indexing a string with 'df' would raise TypeError, so
            # such answers fall through to the plain-text branch instead.
            slots_df = response['df']
            congrats = f"Congratulations!!! You have recycled ***{response['num_of_tyres']}*** tyres."
            st.dataframe(slots_df.set_index(slots_df.columns[0]))
            st.write(congrats)
            st.session_state.messages.append({"role": "assistant", "content": slots_df, "type": "table"})
            st.session_state.messages.append({"role": "assistant", "content": congrats, "type": "string"})
        else:
            st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response, "type": "string"})

    # Only question-answer exchanges are recorded in the RAG memory.
    if st.session_state.last_tool == "question-answer-tool":
        st.session_state.rag_memory.save_context({'input': prompt}, {'output': response})