| from dotenv import load_dotenv |
| import os |
| import streamlit as st |
| from langchain_aws import BedrockEmbeddings |
| from langchain_core.vectorstores import InMemoryVectorStore |
| from langchain.chat_models import init_chat_model |
| from langchain_core.documents import Document |
| from typing_extensions import List, Dict |
|
|
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langgraph.graph import START, StateGraph, END |
|
|
| from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader |
| from langgraph.graph import MessagesState |
| from langchain_core.tools import tool |
| from langchain_core.messages import SystemMessage |
| from langgraph.prebuilt import ToolNode, tools_condition |
| from langchain_milvus import Milvus |
| from utils import extract_text_from_content |
| from logging_config import setup_logger |
| from load_vector_db import init_vector_db |
| from graph_question_validation import init_graph |
| import time |
|
|
logger = setup_logger(__name__)

# Load variables from .env; override=True lets values from the file win over
# anything already present in the process environment.
load_dotenv(override=True)

# Upper-case names read by the AWS SDK -> lower-case names this app's .env
# presumably uses (they are read right after load_dotenv above).
_AWS_ENV_ALIASES = {
    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    "AWS_SESSION_TOKEN": "aws_session_token",
}


def _export_aws_settings() -> None:
    """Copy the lower-case AWS settings onto the upper-case environment variables.

    Only variables that are actually set are copied. Assigning ``None`` into
    ``os.environ`` raises ``TypeError``, which would abort the app at start-up
    whenever one of them is absent (e.g. no session token for long-lived keys).
    Credential values are never printed or logged; only the *names* of missing
    required settings are reported.
    """
    for target, source in _AWS_ENV_ALIASES.items():
        value = os.environ.get(source)
        if value is not None:
            os.environ[target] = value

    missing = [
        name
        for name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")
        if not os.environ.get(name)
    ]
    if missing:
        logger.warning(f"AWS environment variables not set: {', '.join(missing)}")


_export_aws_settings()
# Per-session defaults. Streamlit re-executes this script on every interaction,
# so only seed keys that this browser session does not have yet:
#   messages    - chat transcript replayed in the UI on each rerun
#   initialized - whether init_graph() has already succeeded for this session
st.session_state.setdefault("messages", [])
st.session_state.setdefault("initialized", False)
# Reply shown (and recorded in the conversation history) when the graph streams no text.
NO_ANSWER_MESSAGE = "Could not find any relevant information in the documents."

# Graph nodes whose LLM output is streamed to the user; messages from other nodes are ignored.
STREAMED_NODES = ("generate_questions", "generate_answer_to_query")


def _conversation_history() -> list:
    """Return this browser session's LLM conversation history, creating it on first use.

    The list lives in ``st.session_state`` so each session has its own. An attribute
    on the ``streamlit`` module object would be one list shared by every session the
    process serves.
    """
    return st.session_state.setdefault("conversation_history", [])


def run_graph(graph, input_message: str):
    """Stream the graph's answer to *input_message*, recording the turn in the session history.

    The user turn is appended to the session's conversation history, and the whole
    history is passed to ``graph.stream`` with ``stream_mode=["messages", "values"]``.

    Yields:
        ``(chunk_text, values)`` for every non-empty text chunk produced by one of
        ``STREAMED_NODES``. ``values`` is the list of ``"values"`` snapshots received
        so far. The same list object is yielded every time and keeps growing until the
        stream ends, so once the generator is exhausted its last item is the final
        graph state.

    When the stream ends, the assistant turn is appended to the history. If nothing
    was streamed, ``NO_ANSWER_MESSAGE`` is recorded instead, so the history never holds
    an empty reply. ``(full_response, values)`` is the generator's return value.
    If streaming raises, the unanswered user turn is removed and the error propagates.
    """
    history = _conversation_history()
    user_turn = {"role": "user", "content": input_message}
    history.append(user_turn)

    response_chunks = []
    values = []
    start = time.perf_counter()
    time_to_first_chunk = None
    try:
        for mode, mode_chunk in graph.stream(
            {"messages": history},
            stream_mode=["messages", "values"],
        ):
            if mode == "values":
                values.append(mode_chunk)
                continue
            if mode != "messages":
                continue

            message, metadata = mode_chunk
            if metadata.get("langgraph_node") not in STREAMED_NODES:
                continue
            if not hasattr(message, "content"):
                continue

            if time_to_first_chunk is None:
                time_to_first_chunk = time.perf_counter() - start
                logger.info(f"Time taken to start streaming: {time_to_first_chunk} seconds")

            chunk_text = extract_text_from_content(message.content)
            if chunk_text:
                response_chunks.append(chunk_text)
                yield chunk_text, values
    except Exception:
        # Drop the unanswered user turn so the history keeps alternating user/assistant.
        if history and history[-1] is user_turn:
            history.pop()
        raise

    # Join once at the end instead of re-joining every chunk (was O(n^2)).
    full_response = "".join(response_chunks)
    logger.info(f"Time taken for complete generation: {time.perf_counter() - start} seconds")
    history.append({"role": "assistant", "content": full_response or NO_ANSWER_MESSAGE})
    return full_response, values


def _format_sources(states: list) -> str:
    """Build the markdown "sources" footer from the final graph state.

    ``states`` is the list of ``"values"`` snapshots collected by ``run_graph``; only
    the last one is used. Each retrieved document contributes one line. Lines are
    keyed on ``(source, page)``, so a page hit by several chunks is listed once and
    the first hit wins. Returns ``""`` when nothing was retrieved.

    This is best-effort. If the state or a document's metadata is not shaped as
    expected, the error is logged and a short notice is returned. The answer has
    already been streamed, so it is not discarded.
    """
    try:
        final_state = states[-1]
        logger.info(f"values keys: {list(final_state.keys())}")
        lines = []
        seen_pages = set()
        for doc in final_state.get("retrieved_docs") or []:
            meta = doc.metadata
            page_key = (meta["source"], meta["page"])
            if page_key in seen_pages:
                continue
            seen_pages.add(page_key)
            # Single quotes inside the f-string keep this valid before Python 3.12.
            lines.append(
                f"Source: {meta['source']}, Page: {meta['page']}, "
                f"Page Label: {meta['page_label']}, "
                f"Relevance Score : {meta['relevance_score']}"
            )
    except Exception:
        logger.exception("Error processing values")
        return "\n\nNo sources found for the response."
    return "".join(f"\n\n{line}" for line in lines)


st.title("PDF Question-Answering Chat")

# Build the graph once per browser session; on failure show the error and halt this run.
if not st.session_state.initialized:
    try:
        app_components = init_graph()
    except Exception as e:
        logger.exception("Error initializing app")
        st.error(f"Error initializing app: {e}")
        st.stop()
    else:
        st.session_state.app_components = app_components
        st.session_state.conversation_history = []
        st.session_state.initialized = True

# Streamlit reruns the whole script on each interaction, so replay the transcript.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask a question about your PDFs"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        try:
            full_response = ""
            states = []
            # run_graph only yields non-empty chunks; it also records both turns in the
            # session's conversation history, so nothing is appended to it here.
            for chunk, states in run_graph(st.session_state.app_components["graph"], prompt):
                full_response += chunk
                # Trailing block cursor signals the answer is still streaming.
                message_placeholder.markdown(full_response + "▌")

            if full_response:
                full_response += _format_sources(states)
            else:
                full_response = NO_ANSWER_MESSAGE

            message_placeholder.markdown(full_response)
            logger.info(f"Full response: {full_response}")
            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            import traceback

            logger.exception("Error processing query")
            error_msg = f"Error processing your query: {str(e)}\n\n```\n{traceback.format_exc()}\n```"
            message_placeholder.error(error_msg)
            st.session_state.messages.append({"role": "assistant", "content": error_msg})