|
|
from dotenv import load_dotenv |
|
|
import os |
|
|
import streamlit as st |
|
|
from langchain_aws import BedrockEmbeddings |
|
|
from langchain_core.vectorstores import InMemoryVectorStore |
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain_core.documents import Document |
|
|
from typing_extensions import List, Dict |
|
|
|
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langgraph.graph import START, StateGraph, END |
|
|
|
|
|
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader |
|
|
from langgraph.graph import MessagesState |
|
|
from langchain_core.tools import tool |
|
|
from langchain_core.messages import SystemMessage |
|
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
from langchain_milvus import Milvus |
|
|
from utils import extract_text_from_content |
|
|
from logging_config import setup_logger |
|
|
from load_vector_db import init_vector_db |
|
|
from graph_basic import init_graph |
|
|
import time |
|
|
|
|
|
logger = setup_logger(__name__)

# Load variables from a local .env file; override=True lets the .env values
# win over anything already set in the process environment.
load_dotenv(override=True)

# Promote the lowercase credential names used in .env to the uppercase names
# that boto3 / langchain-aws actually read. Skip any that are missing instead
# of crashing: os.environ assignment requires a str, so the original
# `os.environ[k] = os.environ.get(src)` raised TypeError on absent vars.
_AWS_ENV_MAP = {
    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    "AWS_SESSION_TOKEN": "aws_session_token",
    "AWS_DEFAULT_REGION": "AWS_DEFAULT_REGION",  # no-op rename, kept for symmetry
}
for _target, _source in _AWS_ENV_MAP.items():
    _value = os.environ.get(_source)
    if _value is not None:
        os.environ[_target] = _value

# SECURITY: never print credential values (the original printed the access key
# to stdout). Log only whether the key is configured.
logger.info("AWS_ACCESS_KEY_ID configured: %s", "AWS_ACCESS_KEY_ID" in os.environ)
|
|
|
|
|
|
|
|
|
|
|
# Seed per-session defaults exactly once; Streamlit re-executes this script on
# every interaction, so only fill in keys that are not already present.
for _key, _default in (("messages", []), ("initialized", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
|
|
|
|
|
|
def run_graph(graph, input_message: str):
    """Stream a response for *input_message* through the LangGraph graph.

    Appends the user turn to the shared conversation history, streams the
    graph in both "messages" and "values" modes, and yields
    ``(chunk_text, values)`` pairs as generation tokens arrive, where
    *values* is the (growing) list of full-state snapshots seen so far.
    After the stream is exhausted, the assembled assistant reply is appended
    to the conversation history.

    NOTE(review): ``st.conversation_history`` is an attribute set on the
    streamlit *module*, not on ``st.session_state`` — it is shared across all
    user sessions in one process. Looks unintended; confirm and migrate to
    session_state (kept as-is here so sibling code paths stay consistent).

    Args:
        graph: A compiled LangGraph graph exposing ``.stream(...)``.
        input_message: The raw user question.

    Yields:
        tuple[str, list]: the next text chunk and the values-snapshot list.

    Returns:
        tuple[str, list]: full response text and all snapshots (available to
        callers via StopIteration; the Streamlit caller ignores it).
    """
    st.conversation_history.append(
        {
            "role": "user",
            "content": input_message,
        }
    )

    input_message_formatted = {
        "messages": st.conversation_history
    }

    response_chunks = []
    values = []

    start = time.time()
    time_to_start_streaming = None
    for mode, mode_chunk in graph.stream(
        input_message_formatted,
        stream_mode=["messages", "values"],
    ):
        if mode == "values":
            # Full graph-state snapshot (contains retrieved context, etc.).
            values.append(mode_chunk)
        elif mode == "messages":
            message, metadata = mode_chunk
            # Only surface tokens produced by the "generate" node; other
            # nodes (retrieval, tools) also emit messages we don't display.
            # .get() instead of [] so a chunk without node metadata is
            # skipped rather than raising KeyError mid-stream.
            if metadata.get("langgraph_node") == "generate":
                if hasattr(message, 'content'):
                    if time_to_start_streaming is None:
                        time_to_start_streaming = time.time() - start
                        logger.info(f"Time taken to start streaming: {time_to_start_streaming} seconds")
                    content = message.content

                    chunk_text = extract_text_from_content(content)

                    if chunk_text:
                        response_chunks.append(chunk_text)
                        yield chunk_text, values

    # Join once after the stream ends. The original re-joined the entire
    # chunk list on every yielded token, which is quadratic in response
    # length; ''.join([]) == "" covers the no-chunks case.
    full_response = ''.join(response_chunks)

    logger.info(f"Time taken for complete generation: {time.time() - start} seconds")

    st.conversation_history.append({
        "role": "assistant",
        "content": full_response
    })
    return full_response, values
|
|
|
|
|
|
|
|
st.title("PDF Question-Answering Chat")

# One-time app setup: build the graph/components and cache them in session
# state so reruns skip the expensive initialization. Any failure is surfaced
# to the user and halts the script.
# NOTE(review): conversation_history is stored on the streamlit module itself
# rather than in session_state — shared across sessions; confirm intent.
if not st.session_state.initialized:
    try:
        st.session_state.app_components = init_graph()
        st.session_state.initialized = True
        st.conversation_history = []
    except Exception as e:
        st.error(f"Error initializing app: {e}")
        st.stop()
|
|
|
|
|
|
|
|
# Replay the stored transcript so the chat survives Streamlit reruns.
for past_turn in st.session_state.messages:
    with st.chat_message(past_turn["role"]):
        st.markdown(past_turn["content"])
|
|
|
|
|
|
|
|
# Main chat turn: take a prompt, stream the answer, then append source
# citations extracted from the last graph-state snapshot.
if prompt := st.chat_input("Ask a question about your PDFs"):

    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()

        try:
            full_response = ""
            sources_text = "\n\n"
            values = {}
            for chunk, values in run_graph(st.session_state.app_components["graph"], prompt):
                if chunk:
                    full_response += chunk
                    # Trailing block cursor signals "still generating".
                    message_placeholder.markdown(full_response + "▌")

            try:
                # Keep only the final state snapshot; IndexError on an empty
                # list (no snapshots) falls through to the except below.
                values = values[-1]
                logger.info(f"values keys: {values.keys()}")
                logger.info(f"'context' in values: { 'context' in values }")

                if 'context' in values:
                    # Deduplicate citations by (source, page) so each page is
                    # listed once even when several chunks came from it.
                    pages_dict = {}

                    for i in values['context']:
                        key = (i.metadata['source'], i.metadata['page'])
                        if key not in pages_dict:
                            pages_dict[key] = {
                                "source": i.metadata['source'],
                                "page": i.metadata['page'],
                                "page_label": i.metadata['page_label'],
                                "relevance_score": i.metadata["relevance_score"],
                            }
                            # BUGFIX: the original nested double quotes inside
                            # a double-quoted f-string ({i.metadata["relevance_score"]}),
                            # a SyntaxError on Python < 3.12 (pre-PEP 701).
                            sources_text += f"Source: {i.metadata['source']}, Page: {i.metadata['page']}, Page Label: {i.metadata['page_label']}, Relevance Score : {i.metadata['relevance_score']}\n\n"
            except Exception as e:
                logger.error(f"Error processing values: {e}")
                sources_text = "No sources found for the response."

            if full_response == "":
                full_response = "Could not find any relevant information in the documents."
                # Record the fallback reply so the model sees it on the next
                # turn (non-empty replies are recorded inside run_graph).
                st.conversation_history.append(
                    {
                        "role": "assistant",
                        "content": full_response
                    }
                )
            else:
                full_response += "\n\n" + sources_text

            message_placeholder.markdown(full_response)
            logger.info(f"Full response: {full_response}")

            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            import traceback
            error_msg = f"Error processing your query: {str(e)}\n\n```\n{traceback.format_exc()}\n```"
            message_placeholder.error(error_msg)
            st.session_state.messages.append({"role": "assistant", "content": error_msg})