# multimodal_rag_chatbot / app_basic.py
# Uploaded by vamsidharmuthireddy — commit b3f819d (verified), "Upload 6 files"
from dotenv import load_dotenv
import os
import streamlit as st
from langchain_aws import BedrockEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from typing_extensions import List, Dict
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph, END
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langgraph.graph import MessagesState
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_milvus import Milvus
from utils import extract_text_from_content
from logging_config import setup_logger
from load_vector_db import init_vector_db
from graph_basic import init_graph
import time
logger = setup_logger(__name__)

# Load environment variables from a local .env file, overriding any already set.
load_dotenv(override=True)

# Promote the lowercase credential names used in the .env file to the uppercase
# names that boto3/Bedrock expects. Skip any that are missing instead of
# crashing on an `os.environ[...] = None` assignment (the original TypeError).
_AWS_ENV_MAP = {
    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    "AWS_SESSION_TOKEN": "aws_session_token",
    "AWS_DEFAULT_REGION": "AWS_DEFAULT_REGION",
}
for _target, _source in _AWS_ENV_MAP.items():
    _value = os.environ.get(_source)
    if _value is not None:
        os.environ[_target] = _value
    else:
        logger.warning("Environment variable %s is not set; %s left unchanged", _source, _target)

# SECURITY: the original printed the raw access key id to stdout. Never emit
# credentials — log only whether the key is configured.
logger.info("AWS_ACCESS_KEY_ID configured: %s", "AWS_ACCESS_KEY_ID" in os.environ)
# Seed per-session defaults on first run so the rest of the script can assume
# these keys exist in st.session_state.
for _key, _default in (("messages", []), ("initialized", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
def run_graph(graph, input_message: str):
    """Stream one assistant turn from the RAG graph.

    Appends the user message to the module-level ``st.conversation_history``,
    streams the graph in both "messages" (token chunks) and "values" (state
    snapshot) modes, and yields ``(chunk_text, values)`` for every non-empty
    text chunk produced by the "generate" node. After streaming completes the
    assembled reply is appended to the conversation history.

    Args:
        graph: Compiled LangGraph app exposing ``stream(input, stream_mode=...)``.
        input_message: The raw user prompt for this turn.

    Yields:
        Tuples of ``(chunk_text, values)`` where ``values`` is the list of
        state snapshots observed so far.
    """
    st.conversation_history.append({"role": "user", "content": input_message})
    graph_input = {"messages": st.conversation_history}

    response_chunks = []
    values = []
    start = time.time()
    time_to_start_streaming = None

    # Stream state snapshots and token chunks simultaneously.
    for mode, mode_chunk in graph.stream(graph_input, stream_mode=["messages", "values"]):
        if mode == "values":
            values.append(mode_chunk)
        elif mode == "messages":
            message, metadata = mode_chunk
            # Only surface tokens emitted by the answer-generation node.
            if metadata["langgraph_node"] == "generate" and hasattr(message, "content"):
                if time_to_start_streaming is None:
                    time_to_start_streaming = time.time() - start
                    logger.info(f"Time taken to start streaming: {time_to_start_streaming} seconds")
                # Content may be a plain string or a structured list; normalize it.
                chunk_text = extract_text_from_content(message.content)
                if chunk_text:
                    response_chunks.append(chunk_text)
                    yield chunk_text, values

    full_response = ''.join(response_chunks)
    logger.info(f"Time taken for complete generation: {time.time() - start} seconds")
    st.conversation_history.append({
        "role": "assistant",
        "content": full_response
    })
    # NOTE: in a generator this return value becomes StopIteration.value and is
    # ignored by plain ``for`` iteration; kept for API compatibility.
    return full_response, values
# Main Streamlit UI
st.title("PDF Question-Answering Chat")

# One-time setup: build the graph and start a fresh conversation history.
if not st.session_state.initialized:
    try:
        st.session_state.app_components = init_graph()
        st.session_state.initialized = True
        # NOTE(review): this lives on the streamlit module object, not in
        # session_state, so it is not per-session — run_graph reads the same
        # attribute; confirm this is intentional.
        st.conversation_history = []
    except Exception as e:
        st.error(f"Error initializing app: {e}")
        st.stop()
# Replay the chat transcript stored in session state on every rerun.
for past_message in st.session_state.messages:
    with st.chat_message(past_message["role"]):
        st.markdown(past_message["content"])
# Accept user input
if prompt := st.chat_input("Ask a question about your PDFs"):
    # Record and echo the user message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Stream the assistant response into a live placeholder.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        try:
            full_response = ""
            sources_text = "\n\n"
            values = {}
            for chunk, values in run_graph(st.session_state.app_components["graph"], prompt):
                if chunk:  # Only process non-empty chunks
                    full_response += chunk
                    message_placeholder.markdown(full_response + "▌")

            # Build a source summary from the final state snapshot.
            try:
                values = values[-1]
                logger.info(f"values keys: {values.keys()}")
                logger.info(f"'context' in values: { 'context' in values }")
                if 'context' in values:
                    # Deduplicate by (source, page) so each page is listed once.
                    pages_dict = {}
                    for doc in values['context']:
                        key = (doc.metadata['source'], doc.metadata['page'])
                        if key not in pages_dict:
                            pages_dict[key] = {
                                "source": doc.metadata['source'],
                                "page": doc.metadata['page'],
                                "page_label": doc.metadata['page_label'],
                                "relevance_score": doc.metadata["relevance_score"],
                            }
                            # FIX: the original nested double quotes inside a
                            # double-quoted f-string here — a SyntaxError before
                            # Python 3.12. Single quotes are version-safe.
                            sources_text += f"Source: {doc.metadata['source']}, Page: {doc.metadata['page']}, Page Label: {doc.metadata['page_label']}, Relevance Score : {doc.metadata['relevance_score']}\n\n"
            except Exception as e:
                logger.error(f"Error processing values: {e}")
                sources_text = "No sources found for the response."

            if full_response == "":
                # Nothing was generated; record a fallback answer in history.
                full_response = "Could not find any relevant information in the documents."
                st.conversation_history.append({
                    "role": "assistant",
                    "content": full_response
                })
            else:
                full_response += "\n\n" + sources_text

            message_placeholder.markdown(full_response)
            logger.info(f"Full response: {full_response}")
            # Add assistant response to chat history
            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            import traceback
            error_msg = f"Error processing your query: {str(e)}\n\n```\n{traceback.format_exc()}\n```"
            message_placeholder.error(error_msg)
            st.session_state.messages.append({"role": "assistant", "content": error_msg})