Spaces:
Sleeping
Sleeping
| import re | |
| import streamlit as st | |
| import random | |
| from app_config import SYSTEM_PROMPT, NLP_MODEL_NAME, NUMBER_OF_VECTORS_FOR_RAG, NLP_MODEL_TEMPERATURE, NLP_MODEL_MAX_TOKENS, VECTOR_MAX_TOKENS, SLOT_ID_PATTERN, INVOICE_NUM_PATTERN | |
| from utils.functions import get_vectorstore, tiktoken_len | |
| from langchain.memory import ConversationSummaryBufferMemory | |
| from langchain_groq import ChatGroq | |
| from langchain.agents import initialize_agent | |
| from langchain.agents.agent_types import AgentType | |
| from dotenv import load_dotenv | |
| from pathlib import Path | |
| import os | |
| from tools.tools import response_generator,cancle_ongoing_process,cancle_slot,get_invoice,get_slot_details,schedule_slot,update_slot,price_estimation | |
| import session_manager | |
# Load environment variables (GROQ_API_KEY is read from the environment below)
# from a ".env" file in the current working directory; python-dotenv does not
# raise when the file is absent.
env_path = Path(".").joinpath(".env")
load_dotenv(dotenv_path=env_path)
# Page-level CSS overrides: one element class gets a reversed flex row with
# right-aligned text, another element class is hidden entirely.
# NOTE(review): these "st-emotion-cache-*" names look like Streamlit's
# generated class hashes, which may change between Streamlit releases —
# confirm they still match the installed version.
_CHAT_CSS = """
<style>
.st-emotion-cache-janbn0 {
flex-direction: row-reverse;
text-align: right;
}
.st-emotion-cache-1ec2a3d{
display: none;
}
</style>
"""
st.markdown(_CHAT_CSS, unsafe_allow_html=True)
# Initialize per-session state. Streamlit re-executes this script on every
# interaction, so each entry is created only when st.session_state lacks it.
print("SYSTEM MESSAGE")
if "messages" not in st.session_state:
    # Chat transcript, seeded with the system prompt.
    st.session_state.messages = [{"role": "system", "content": SYSTEM_PROMPT}]
print("SYSTEM MODEL")
if "llm" not in st.session_state:
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        # Fail fast: str(None) would hand Groq the literal key "None" and the
        # problem would only surface later as an authentication error.
        st.error("GROQ_API_KEY is not set. Add it to the environment or to the .env file.")
        st.stop()
    st.session_state.llm = ChatGroq(
        temperature=NLP_MODEL_TEMPERATURE,
        groq_api_key=groq_api_key,
        model_name=NLP_MODEL_NAME,
    )
print("rag")
if "rag_memory" not in st.session_state:
    # Memory token budget = model limit minus the system prompt and the
    # NUMBER_OF_VECTORS_FOR_RAG retrieved chunks (VECTOR_MAX_TOKENS each).
    st.session_state.rag_memory = ConversationSummaryBufferMemory(
        llm=st.session_state.llm,
        max_token_limit=(
            NLP_MODEL_MAX_TOKENS
            - tiktoken_len(SYSTEM_PROMPT)
            - VECTOR_MAX_TOKENS * NUMBER_OF_VECTORS_FOR_RAG
        ),
    )
print("retrival")
if "retriever" not in st.session_state:
    # LangChain's as_retriever() takes the number of documents to fetch via
    # search_kwargs; a bare k= keyword is not used as the search's k, so the
    # configured NUMBER_OF_VECTORS_FOR_RAG was likely being ignored.
    st.session_state.retriever = get_vectorstore().as_retriever(
        search_kwargs={"k": NUMBER_OF_VECTORS_FOR_RAG}
    )
print("agent_history")
if "agent_history" not in st.session_state:
    # Per-agent conversation turns, keyed by agent name.
    st.session_state.agent_history = {}
print("next agent")
if "next_agent" not in st.session_state:
    # Name of the agent that will handle the next user query.
    st.session_state.next_agent = "general_agent"
print("last_query")
if "last_query" not in st.session_state:
    st.session_state.last_query = ""
print("last_tool")
if "last_tool" not in st.session_state:
    # Name of the most recently used tool; drives how replies are rendered.
    st.session_state.last_tool = ""
print("agent")
# Share the live session state with session_manager — presumably so the tools
# can read/update next_agent, last_tool and agent_history (verify there).
session_manager.set_session_state(st.session_state)
# Initialize all agents once per session. "general_agent" has every tool; each
# flow-specific agent gets only its own tool plus cancle_ongoing_process.
def _build_agent(tools):
    """Return a structured-chat ReAct agent over ``tools`` using the session LLM.

    Every agent shares one configuration: verbose logging, at most 3
    iterations, and ``handle_parsing_errors=True`` so malformed LLM output is
    handled by LangChain instead of raising.
    """
    return initialize_agent(
        tools=tools,
        llm=st.session_state.llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        max_iterations=3,
        handle_parsing_errors=True,
    )


if "agents" not in st.session_state:
    # Keys must match the values stored in st.session_state.next_agent; the
    # "canclelation" spelling is kept deliberately — presumably the tools set
    # next_agent to that exact string (verify before renaming).
    st.session_state.agents = {
        "general_agent": _build_agent([
            response_generator,
            schedule_slot,
            cancle_ongoing_process,
            cancle_slot,
            update_slot,
            get_slot_details,
            get_invoice,
            price_estimation,
        ]),
        "slot_booking_agent": _build_agent([schedule_slot, cancle_ongoing_process]),
        "slot_canclelation_agent": _build_agent([cancle_slot, cancle_ongoing_process]),
        "slot_update_agent": _build_agent([update_slot, cancle_ongoing_process]),
        "get_invoice_agent": _build_agent([get_invoice, cancle_ongoing_process]),
        "price_estimation_agent": _build_agent([price_estimation, cancle_ongoing_process]),
    }
# Save an uploaded tyre image to ./Tyre.png (os.path.join with one argument
# was a no-op, so the path is written directly).
# NOTE(review): the same file is used for every browser session and for
# jpg/jpeg uploads as well; presumably a tool reads "Tyre.png" — confirm
# before renaming. The write repeats on every script rerun while an upload
# is present.
if img_file_buffer := st.file_uploader(
    "Upload a Tyre image",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=False,
    label_visibility="hidden",
):
    Path("Tyre.png").write_bytes(img_file_buffer.getbuffer())
print("container")
# Replay the stored chat history (everything except the system prompt) inside
# a 600-px-high container. Each entry's "type" picks the renderer: pandas
# table, raw HTML, or plain st.write.
container = st.container(height=600)
for entry in st.session_state.messages:
    if entry["role"] == "system":
        continue
    with container.chat_message(entry["role"]):
        kind = entry["type"]
        if kind == "table":
            frame = entry["content"]
            # Use the first column as the index so it is shown as row labels.
            st.dataframe(frame.set_index(frame.columns[0]))
        elif kind == "html":
            st.markdown(entry["content"], unsafe_allow_html=True)
        else:
            st.write(entry["content"])
# When the user submits a query: echo it, route it to the currently active
# agent (st.session_state.next_agent), then render and store the reply.
if prompt := st.chat_input("Enter your query here... "):
    with container.chat_message("user"):
        st.write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt, "type": "string"})
    st.session_state.last_query = prompt
    with container.chat_message("assistant"):
        active_agent = st.session_state.next_agent
        # A flow-specific agent with recorded history gets its earlier turns
        # replayed as a "user:/ai:" transcript; the general agent only sees
        # the new prompt.
        transcript = []
        if active_agent != "general_agent" and active_agent in st.session_state.agent_history:
            for turn in st.session_state.agent_history[active_agent]:
                if turn["role"] == "user":
                    transcript.append(f"user: {turn['content']}\n")
                elif turn["role"] == "assistant":
                    transcript.append(f"ai: {turn['content']}\n")
        transcript.append(f"user: {prompt}\n")
        current_conversation = "".join(transcript)

        print("***************************************** HISTORY ********************************************")
        print(st.session_state.agent_history)
        print("****************************************** Messages ******************************************")
        print("messages", current_conversation)
        print()
        print()
        # Chain.__call__ is deprecated in LangChain; invoke() with the agent's
        # single "input" key is the equivalent call.
        response = st.session_state.agents[active_agent].invoke({"input": current_conversation})["output"]
        print("******************************************************** Response ********************************************************")
        print("MY RESPONSE IS:", response)

        # last_tool is never reset in this script, so it can still name the
        # previous turn's tool when the agent answered without calling one.
        last_tool = st.session_state.last_tool
        if last_tool == "get-invoice-tool":
            st.session_state.messages.append({"role": "assistant", "content": response, "type": "html"})
            st.markdown(response, unsafe_allow_html=True)
        elif last_tool == "slot-fetching-tool" and not isinstance(response, str):
            # Guarded: a stale last_tool paired with a plain-text answer used
            # to crash on response['df']. The tool output is expected to carry
            # a 'df' DataFrame and a 'num_of_tyres' count.
            slots_df = response["df"]
            st.dataframe(slots_df.set_index(slots_df.columns[0]))
            congrats = f"Congratulations!!! You have recycled ***{response['num_of_tyres']}*** tyres."
            st.write(congrats)
            st.session_state.messages.append({"role": "assistant", "content": slots_df, "type": "table"})
            st.session_state.messages.append({"role": "assistant", "content": congrats, "type": "string"})
        else:
            st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response, "type": "string"})
        if last_tool == "question-answer-tool":
            # Only Q&A turns are written to the RAG conversation memory.
            st.session_state.rag_memory.save_context({"input": prompt}, {"output": response})