from dotenv import load_dotenv
import os
import streamlit as st
from langchain_aws import BedrockEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from typing import List, Dict, Any
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph, END
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langgraph.graph import MessagesState
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, AIMessage, ToolMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_milvus import Milvus
from utils import extract_text_from_content
from logging_config import setup_logger
from load_vector_db import init_vector_db
from logging_config import setup_logger
import time
from pydantic import BaseModel, Field

logger = setup_logger(__name__)


def init_graph():
    """Initialize the app components and return them.

    Builds the Bedrock embeddings, the vector store / compression retriever
    (via ``init_vector_db``) and the Bedrock chat model, then wires the
    question-validation, retrieval and answer-generation nodes into a
    LangGraph ``StateGraph`` and compiles it.

    Returns:
        dict: ``{"graph": <compiled graph>}``.
    """
    with st.spinner("Initializing PDF chat application..."):
        embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
        vector_store, compression_retriever = init_vector_db(embeddings)

        # Data model
        llm = init_chat_model(
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
            model_provider="bedrock_converse",
            temperature=0,
        )

        class State(MessagesState):
            # NOTE(review): State is a TypedDict subclass; the class-level
            # defaults below do not appear to be applied by LangGraph. The
            # `"key" in state` checks in the gate node rely on keys being
            # absent until a node writes them — confirm before relying on it.
            user_query: str = ""
            # Holds a ValidateQuery instance (read via attribute access downstream).
            query_validation: Dict[str, Any] = {}
            retrieved_docs: List[Document] = []
            generate_response: bool = False
            generate_sample_questions: bool = False
            sample_questions: str = ""
            final_response: str = ""

        class ValidateQuery(BaseModel):
            """Binary score for question validation."""

            valid_question: bool = Field(
                description="Question is clear and answerable, 'true' or 'false'"
            )
            response: str = Field(
                description="Explanation of the question's validity and suggestions for improvement."
            )

        def _invalid_validation(reason: str) -> ValidateQuery:
            """Build a negative validation result carrying *reason* as its explanation."""
            return ValidateQuery(valid_question=False, response=reason)

        def validate_query(state: State):
            """Validate if the question is clear and answerable.

            Always writes ``query_validation`` as a ``ValidateQuery`` instance
            (downstream nodes read ``.valid_question`` / ``.response`` as
            attributes) and always writes ``user_query`` (read by
            ``generate_questions``).
            """
            start = time.time()
            # Get the latest human message
            human_messages = [msg for msg in state["messages"] if msg.type == "human"]
            if not human_messages:
                return {
                    "query_validation": _invalid_validation("No question found to validate."),
                    "user_query": "",
                }
            latest_question = human_messages[-1].content

            validation_prompt = """
You are a question validation assistant. Analyze the following question and determine if it is:
1. Clear and specific (not vague or ambiguous)
2. Focused on a single topic or closely related topics (not trying to address too many different things)
3. Answerable without making assumptions
4. Concise and well-structured

# "If the query is single word or phrase, ask the user to provide a complete question."
# "If the query is not clear, ask for clarification."
# "If the query is not a complete question, ask the user to provide a complete question and provide some sample questions."
# "If the query contains multiple questions, answer only the first question and ask the user to ask the next question."
# "If the query contains complex or compound questions, break them down into simpler parts and answer each part separately."
# "If the query is not related to the given knowledge source, mention that you can only answer from the knowledge base."

Respond with a JSON object containing:
- "valid_question": boolean (true if question is valid, false if not)
- "response": Explain about the validity of the question and suggest improvements. But dont give any example questions.

Examples of invalid questions:
- "Tell me everything about X" (too broad)
- "What about A, B, C, and also how does D relate to E?" (too many topics)
- "Why is it better?" (vague, missing context)
- Questions that would require assumptions about unstated context

Examples of valid questions:
- "What is the definition of X?"
- "How does process A work?"
- "What are the main benefits of technique B?"
"""
            try:
                structured_llm_grader = llm.with_structured_output(ValidateQuery)
                user_prompt = [f"Question to analyze: {latest_question}"]
                prompt = [SystemMessage(validation_prompt)] + user_prompt
                validation_response = structured_llm_grader.invoke(prompt)
                if validation_response is None:
                    # Treat a missing structured result like any other validation failure.
                    raise ValueError("structured output returned no result")
            except Exception as e:
                logger.exception(f"Error in question validation: {e}")
                # Default to INVALID so the graph falls through to sample-question generation.
                return {
                    "query_validation": _invalid_validation("Error in Question validation"),
                    "user_query": latest_question,
                }

            end = time.time()
            logger.info(f"Time taken for question validation: {end - start} seconds")
            logger.info(f"Question validation result: {validation_response}")
            return {"query_validation": validation_response, "user_query": latest_question}

        @tool(response_format="content_and_artifact")
        def retrieve_tool(query: str):
            """Retrieve information related to a query."""
            start = time.time()
            retrieved_docs = compression_retriever.invoke(input=query, k=30)
            serialized = "\n\n".join(
                f"Source: {doc.metadata}\nContent: {doc.page_content}"
                for doc in retrieved_docs
            )
            end = time.time()
            logger.info(f"Time taken for vectordb retrieval: {end - start} seconds")
            logger.info(f"Retrieved {len(retrieved_docs)} documents for query: {query}")
            # content_and_artifact: the string goes to the LLM, the Document list
            # is kept on ToolMessage.artifact for extract_retrieved_docs.
            return serialized, retrieved_docs

        def respond_or_call_retrieve_tool(state: State):
            """Generate tool call for retrieval or respond."""
            start = time.time()
            valid_messages = [msg for msg in state["messages"] if msg.content]
            if not valid_messages:
                return {"messages": []}
            llm_with_tools = llm.bind_tools([retrieve_tool])
            # NOTE(review): the full history (not valid_messages) is sent, as before;
            # valid_messages only serves as an emptiness guard.
            response = llm_with_tools.invoke(state["messages"])
            end = time.time()
            logger.info(f"Time taken for query_or_respond_fn LLM invocation: {end - start} seconds")
            # MessagesState appends messages to state instead of overwriting
            return {"messages": [response]}

        def extract_retrieved_docs(state: State):
            """Collect the Document artifacts of every retrieve_tool ToolMessage into ``retrieved_docs``."""
            retrieved_docs = []
            logger.info("Entered extract_retrieved_docs")
            for message in state["messages"]:
                if isinstance(message, ToolMessage) and message.name == "retrieve_tool":
                    try:
                        retrieved_docs.extend(message.artifact)
                    except Exception as e:
                        logger.info(f"Error parsing tool result: {message.content}")
                        logger.info(f"Exception: {e}")
                        continue
            return {"retrieved_docs": retrieved_docs}

        def wait_for_query_validation_and_retrieved_docs(state: State):
            """Gate generation until both the validation and retrieval branches have written state.

            Reached by separate edges from ``validate_query`` and
            ``extract_retrieved_docs``, so it may run more than once; only when
            both ``query_validation`` and ``retrieved_docs`` are present does it
            enable generation (sample questions for invalid questions,
            otherwise a direct answer).
            """
            logger.info("Entered wait_for_query_validation_and_retrieved_docs")
            logger.info(f" state.keys() : {state.keys()}")
            has_validation = "query_validation" in state
            has_docs = "retrieved_docs" in state
            logger.info(f"query_validation status : {has_validation}")
            logger.info(f"retrieved_docs status : {has_docs}")
            logger.info(f"state['query_validation'] : {state.get('query_validation')}")

            if has_validation and has_docs:
                logger.info("Both question validation and retrieved documents are present in state.")
                logger.info(f"state['query_validation'] : {state['query_validation']}")
                if state["query_validation"].valid_question:
                    response = {
                        "generate_sample_questions": False,
                        "generate_response": True,
                    }
                else:
                    # If question is not valid, generate sample questions
                    response = {
                        "generate_sample_questions": True,
                        "generate_response": True,
                    }
            else:
                # One branch has not finished yet -> do not generate on this pass.
                response = {"generate_response": False}
            logger.info(f"wait_for_query_validation_and_retrieved_docs response : {response}")
            return response

        def generate_questions(state: State):
            """Generate sample questions based on the retrieved documents."""
            start = time.time()
            user_query = state["user_query"]
            validation = state["query_validation"]
            # Combine the retrieved documents into a single context block
            combined_response = "\n\n".join(
                f"Source: {doc.metadata}\nContent: {doc.page_content}"
                for doc in state["retrieved_docs"]
            )
            question_generation_prompt = f"""
# Question Nudging Prompt

You are a helpful assistant that generates better question suggestions for users based on available document content.

## Context
A user has asked a question that may be unclear, too broad, or not well-phrased. You have access to relevant documents retrieved from a vector database. Your task is to generate 3-5 sample questions that:
1. Are directly answerable using the provided document content
2. Guide the user toward more specific and actionable queries
3. Help clarify what the user might actually be looking for
4. Are phrased clearly and concisely

## Input Format
**User Question:** {user_query}

**Retrieved Documents:**
{combined_response}

**Question Validity Status:** {validation.valid_question}

**Question validity justification:** {validation.response}

## Instructions
1. Analyze the user's original question and the retrieved document content
2. Identify key topics, concepts, and specific information available in the documents
3. Generate 3-5 sample questions that:
   - Are more specific than the original question
   - Can be fully answered using only the provided document content
   - Cover different aspects or angles related to the user's intent
   - Use clear, direct language
   - Avoid assumptions not supported by the documents
4. During response generation, ad justification of question validity which is provided to you in the state["query_validation"].response

## Output Format
{validation.response}

Based on the available information, here are some more specific questions you might want to ask:

1. [Specific question 1 based on document content]
2. [Specific question 2 based on document content]
3. [Specific question 3 based on document content]
4. [Specific question 4 based on document content] (if applicable)
5. [Specific question 5 based on document content] (if applicable)

## Important Guidelines
- Only suggest questions that can be answered using the provided documents
- Do not make up information not present in the documents
- Focus on practical, actionable questions
- Vary the scope and angle of suggested questions
- Keep questions concise and clear
- If the original question is valid but could be more specific, acknowledge this in your suggestions
"""
            response = llm.invoke(question_generation_prompt)
            end = time.time()
            logger.info(f"Time taken for generate_fn LLM invocation: {end - start} seconds")
            return {"messages": [response], "sample_questions": response.content}

        def generate_answer_to_query(state: State):
            """Generate the final answer grounded only in the retrieved documents."""
            logger.info("Entered generate_answer_to_query")
            start = time.time()
            logger.info(f"state.keys() : {state.keys()}")

            retrieved_docs = state["retrieved_docs"]
            if not retrieved_docs:
                logger.info("No relevant documents found.")
                # A bare str in "messages" would be coerced by add_messages into a
                # HumanMessage; reply as the assistant and fill final_response.
                no_docs = AIMessage("No relevant documents found.")
                return {"messages": [no_docs], "final_response": no_docs.content}

            # Source list is only logged (not sent to the LLM).
            sources_text = "".join(
                f"Source: {doc.metadata.get('source')}, "
                f"Page: {doc.metadata.get('page')}, "
                f"Page Label: {doc.metadata.get('page_label')}\n"
                for doc in retrieved_docs
            )
            logger.info(f"sources_text {sources_text}")

            docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs)
            # Trailing spaces keep the implicitly concatenated sentences separated.
            system_message_content = (
                "You are an assistant for question-answering tasks. "
                "Use the following pieces of retrieved context to answer the question. "
                "This is your only source of knowledge. "
                "If you don't know the answer, say that you don't know and STOP - do not provide related information. "
                "You are not allowed to make up answers. "
                "You are not allowed to use any external knowledge. "
                "You are not allowed to make assumptions. "
                "If the query is not clearly and directly addressed in the knowledge source, simply state that you don't have enough information and DO NOT elaborate with tangentially related content. "
                "Keep your answers strictly limited to information that directly answers the user's specific question. "
                "When information is insufficient, acknowledge this limitation in one sentence without expanding into related topics. "
                "Keep your answers accurate and concise to the source content."
                "\n\n"
                f"{docs_content}"
            )
            # Drop tool-call AI messages and ToolMessages; keep the plain dialogue.
            conversation_messages = [
                message
                for message in state["messages"]
                if message.type in ("human", "system")
                or (message.type == "ai" and not message.tool_calls)
            ]
            prompt = [SystemMessage(system_message_content)] + conversation_messages

            start_llm = time.time()
            response = llm.invoke(prompt)
            end = time.time()
            logger.info(f"Time taken for generate_fn : {end - start} seconds")
            logger.info(f"Time taken for generate_fn LLM invocation: {end - start_llm} seconds")
            return {"messages": [response], "final_response": response.content}

        def generate_flat_response(state: State):
            """Return a fixed fallback assistant message."""
            logger.info("Generateing flat response")
            messages = AIMessage("Unable to generate response. Please try again later.")
            # final_response is declared as str, so store the text, not the message object.
            return {"messages": [messages], "final_response": messages.content}

        def final_response_router(state: State):
            """Route to the appropriate response generation function."""
            logger.info("Entered final_response_router")
            logger.info(f"state.keys() : {state.keys()}")
            logger.info(f"state.get('generate_response'): {state.get('generate_response')}")
            logger.info(f"state.get('generate_sample_questions'): {state.get('generate_sample_questions')}")
            if state.get("generate_response"):
                logger.info("generate_response is True")
                if state.get("generate_sample_questions"):
                    return "generate_questions"
                return "generate_answer_to_query"
            logger.info("generate_response is False")
            return END

        tools_node = ToolNode(tools=[retrieve_tool])

        graph_builder = StateGraph(State)
        graph_builder.add_node(node="validate_query", action=validate_query)
        graph_builder.add_node(node="respond_or_call_retrieve_tool", action=respond_or_call_retrieve_tool)
        graph_builder.add_node(node="tools", action=tools_node)
        graph_builder.add_node(node="extract_retrieved_docs", action=extract_retrieved_docs)
        graph_builder.add_node(
            node="wait_for_query_validation_and_retrieved_docs",
            action=wait_for_query_validation_and_retrieved_docs,
        )
        graph_builder.add_node(node="generate_questions", action=generate_questions)
        graph_builder.add_node(node="generate_answer_to_query", action=generate_answer_to_query)
        graph_builder.add_node(node="generate_flat_response", action=generate_flat_response)

        # Two entry points: validation and retrieval start from START side by side.
        graph_builder.set_entry_point(key="validate_query")
        graph_builder.set_entry_point(key="respond_or_call_retrieve_tool")

        graph_builder.add_conditional_edges(
            source="respond_or_call_retrieve_tool",
            path=tools_condition,
            path_map={END: END, "tools": "tools"},
        )
        graph_builder.add_edge(start_key="validate_query", end_key="wait_for_query_validation_and_retrieved_docs")
        # NOTE(review): this unconditional edge also routes to "tools" when
        # tools_condition chose END, so the retrieval branch always reaches the
        # gate node; kept as-is because downstream state depends on it — confirm intent.
        graph_builder.add_edge(start_key="respond_or_call_retrieve_tool", end_key="tools")
        graph_builder.add_edge(start_key="tools", end_key="extract_retrieved_docs")
        graph_builder.add_edge(start_key="extract_retrieved_docs", end_key="wait_for_query_validation_and_retrieved_docs")

        graph_builder.add_conditional_edges(
            "wait_for_query_validation_and_retrieved_docs",
            final_response_router,
            {
                "generate_questions": "generate_questions",
                "generate_answer_to_query": "generate_answer_to_query",
                "generate_flat_response": "generate_flat_response",
                END: END,
            },
        )
        graph_builder.add_edge(start_key="generate_questions", end_key=END)
        graph_builder.add_edge(start_key="generate_answer_to_query", end_key=END)
        graph_builder.add_edge(start_key="generate_flat_response", end_key=END)

        graph = graph_builder.compile()
        st.success("Initialization complete!")

    return {"graph": graph}