"""Streamlit chat front-end for question answering over PDFs.

The script builds the LangGraph application once per browser session, keeps
the rendered chat transcript and the model-facing conversation history in
``st.session_state``, and streams answers produced by the graph's
``generate`` node into the chat window.
"""

import os
import time
import traceback

import streamlit as st
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_aws import BedrockEmbeddings
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_milvus import Milvus
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import Dict, List

from graph_basic import init_graph
from load_vector_db import init_vector_db
from logging_config import setup_logger
from utils import extract_text_from_content

logger = setup_logger(__name__)

# Load environment variables
load_dotenv(override=True)

# The .env file uses lower-case names for the AWS credentials; copy them to
# the upper-case names. A variable that is absent is skipped, because
# assigning None into os.environ raises TypeError (the session token in
# particular is optional). Credential values are never printed or logged.
_AWS_ENV_ALIASES = {
    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    "AWS_SESSION_TOKEN": "aws_session_token",
}
for _target, _alias in _AWS_ENV_ALIASES.items():
    _value = os.environ.get(_alias)
    if _value is not None:
        os.environ[_target] = _value
    elif _target not in os.environ:
        logger.warning("Environment variable %s is not set", _target)

# Initialize session state variables if they don't exist
if "messages" not in st.session_state:
    st.session_state.messages = []
if "initialized" not in st.session_state:
    st.session_state.initialized = False
# Conversation history sent to the graph. It lives in session_state (not as
# an attribute on the streamlit module) so each browser session has its own.
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []


def run_graph(graph, input_message: str):
    """Stream the graph's answer to ``input_message``.

    The user message is appended to the session's conversation history and
    the whole history is passed to ``graph.stream``. Only message chunks
    emitted by the ``generate`` node are streamed; chunks from other nodes
    are ignored.

    Args:
        graph: Object exposing ``stream(input, stream_mode=[...])`` that
            yields ``(mode, chunk)`` pairs.
        input_message: The user's question.

    Yields:
        ``(chunk_text, values)`` tuples. ``values`` is the same list object
        on every yield; it keeps growing as ``"values"`` events arrive, so
        after the generator is exhausted its last element is the most
        recent state snapshot.

    Returns:
        ``(full_response, values)`` as the generator's return value. A
        non-empty ``full_response`` is also appended to the conversation
        history as the assistant turn; an empty one is left for the caller
        to replace with a fallback message.
    """
    history = st.session_state.conversation_history
    history.append({"role": "user", "content": input_message})
    input_message_formatted = {"messages": history}

    response_chunks = []
    values = []
    start = time.time()
    time_to_start_streaming = None

    for mode, mode_chunk in graph.stream(
        input_message_formatted,
        stream_mode=["messages", "values"],
    ):
        if mode == "values":
            values.append(mode_chunk)
            continue
        if mode != "messages":
            continue

        message, metadata = mode_chunk
        if metadata["langgraph_node"] != "generate":
            continue
        if not hasattr(message, "content"):
            continue

        if time_to_start_streaming is None:
            time_to_start_streaming = time.time() - start
            logger.info(
                "Time taken to start streaming: %s seconds",
                time_to_start_streaming,
            )

        # Extract text depending on content format
        chunk_text = extract_text_from_content(message.content)
        if chunk_text:
            response_chunks.append(chunk_text)
            yield chunk_text, values

    full_response = "".join(response_chunks)
    logger.info(
        "Time taken for complete generation: %s seconds", time.time() - start
    )

    # Record the assistant turn only when there is text; otherwise the
    # caller appends its fallback and the history would end up with an empty
    # assistant message followed by a second assistant message.
    if full_response:
        history.append({"role": "assistant", "content": full_response})
    return full_response, values


def _format_sources(documents) -> str:
    """Return one citation line per distinct (source, page) pair.

    Each element of ``documents`` must expose a ``metadata`` mapping with the
    keys ``source``, ``page``, ``page_label`` and ``relevance_score``.
    """
    seen = set()
    lines = []
    for doc in documents:
        meta = doc.metadata
        key = (meta["source"], meta["page"])
        if key in seen:
            continue
        seen.add(key)
        lines.append(
            f"Source: {meta['source']}, Page: {meta['page']}, "
            f"Page Label: {meta['page_label']}, "
            f"Relevance Score : {meta['relevance_score']}\n\n"
        )
    return "".join(lines)


# Main Streamlit UI
st.title("PDF Question-Answering Chat")

# Initialize the app if not already done
if not st.session_state.initialized:
    try:
        st.session_state.app_components = init_graph()
        st.session_state.initialized = True
    except Exception as e:
        st.error(f"Error initializing app: {e}")
        st.stop()

# Display chat messages from history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("Ask a question about your PDFs"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        try:
            # Stream the response
            full_response = ""
            sources_text = "\n\n"
            state_snapshots = []
            graph = st.session_state.app_components["graph"]
            for chunk, state_snapshots in run_graph(graph, prompt):
                if chunk:  # Only process non-empty chunks
                    full_response += chunk
                    message_placeholder.markdown(full_response + "▌")

            try:
                # The last snapshot is the most recent graph state.
                final_state = state_snapshots[-1] if state_snapshots else {}
                logger.info("values keys: %s", final_state.keys())
                logger.info(
                    "'context' in values: %s", "context" in final_state
                )
                if "context" in final_state:
                    sources_text += _format_sources(final_state["context"])
            except (AttributeError, KeyError, TypeError) as e:
                logger.error(f"Error processing values: {e}")
                sources_text = "No sources found for the response."

            if full_response == "":
                full_response = (
                    "Could not find any relevant information in the documents."
                )
                st.session_state.conversation_history.append(
                    {"role": "assistant", "content": full_response}
                )
            else:
                full_response += "\n\n" + sources_text

            message_placeholder.markdown(full_response)
            logger.info("Full response: %s", full_response)

            # Add assistant response to chat history
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )
        except Exception as e:
            error_msg = f"Error processing your query: {str(e)}\n\n```\n{traceback.format_exc()}\n```"
            message_placeholder.error(error_msg)
            st.session_state.messages.append(
                {"role": "assistant", "content": error_msg}
            )