# multimodal_rag_chatbot / graph_question_validation.py
# Uploaded by vamsidharmuthireddy ("Upload 6 files", commit b3f819d, verified)
# Standard library imports.
import os
import time
from typing import Any, Dict, List

# Third-party imports.
import streamlit as st
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_aws import BedrockEmbeddings
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_milvus import Milvus
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field

# Local application imports.
from load_vector_db import init_vector_db
from logging_config import setup_logger
from utils import extract_text_from_content

# Module-level logger; the duplicate `setup_logger` import was removed.
logger = setup_logger(__name__)
def init_graph():
    """Initialize the app components and return them.

    Builds the embeddings, vector store, retriever and chat model, then
    wires them into a LangGraph pipeline.

    Returns:
        dict: {"graph": compiled LangGraph} ready for invocation by the UI.
    """
    with st.spinner("Initializing PDF chat application..."):
        # Titan embeddings back the vector store / retriever pair.
        embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
        vector_store, compression_retriever = init_vector_db(embeddings)
        # Data model
        # Claude 3.5 Sonnet via Bedrock Converse; temperature 0 for
        # deterministic grading and answering.
        llm = init_chat_model(
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
            model_provider="bedrock_converse",
            temperature=0
        )
class State(MessagesState):
    """Graph state shared by all nodes; extends MessagesState's `messages`."""
    # Latest human question being processed.
    user_query: str = ""
    # Verdict from validate_query — a ValidateQuery instance on success,
    # a plain dict on the fallback paths.
    query_validation: Dict[str, Any] = {}
    # Documents returned by the retrieve_tool artifact.
    retrieved_docs: List[Document] = []
    # Gate flag set by the join node once validation and retrieval finished.
    generate_response: bool = False
    # True when the question was judged invalid and suggestions are wanted.
    generate_sample_questions: bool = False
    # NOTE(review): mutable class-level defaults on a MessagesState subclass
    # are unusual for TypedDict-style states — confirm they are honored.
    sample_questions : str = ""
    final_response : str = ""
class ValidateQuery(BaseModel):
    """Binary score for question validation.

    Used as the structured-output schema for the grader LLM; the Field
    descriptions below are sent to the model and must stay stable.
    """
    # Whether the question is considered answerable as-is.
    valid_question: bool = Field(
        description="Question is clear and answerable, 'true' or 'false'"
    )
    # Human-readable justification / improvement suggestions.
    response: str = Field(
        description="Explanation of the question's validity and suggestions for improvement."
    )
def validate_query(state: State):
    """Validate that the latest human question is clear and answerable.

    Asks the grader LLM for a structured ValidateQuery verdict. Every code
    path now returns a ValidateQuery instance (never a bare dict) so that
    downstream nodes can rely on attribute access (`.valid_question`,
    `.response`) without crashing on the fallback paths.

    Returns:
        dict: partial state update with "query_validation" and "user_query".
    """
    start = time.time()
    # Get the latest human message
    human_messages = [msg for msg in state["messages"] if msg.type == "human"]
    if not human_messages:
        # Nothing to validate. Report an invalid question as a ValidateQuery
        # instance (a dict here would break `.valid_question` in the join
        # node), and keep user_query populated for consistency.
        return {
            "query_validation": ValidateQuery(
                valid_question=False,
                response="No question found to validate.",
            ),
            "user_query": "",
        }
    latest_question = human_messages[-1].content
    validation_prompt = f"""
You are a question validation assistant. Analyze the following question and determine if it is:
1. Clear and specific (not vague or ambiguous)
2. Focused on a single topic or closely related topics (not trying to address too many different things)
3. Answerable without making assumptions
4. Concise and well-structured
# "If the query is single word or phrase, ask the user to provide a complete question."
# "If the query is not clear, ask for clarification."
# "If the query is not a complete question, ask the user to provide a complete question and provide some sample questions."
# "If the query contains multiple questions, answer only the first question and ask the user to ask the next question."
# "If the query contains complex or compound questions, break them down into simpler parts and answer each part separately."
# "If the query is not related to the given knowledge source, mention that you can only answer from the knowledge base."
Respond with a JSON object containing:
- "valid_question": boolean (true if question is valid, false if not)
- "response": Explain about the validity of the question and suggest improvements. But dont give any example questions.
Examples of invalid questions:
- "Tell me everything about X" (too broad)
- "What about A, B, C, and also how does D relate to E?" (too many topics)
- "Why is it better?" (vague, missing context)
- Questions that would require assumptions about unstated context
Examples of valid questions:
- "What is the definition of X?"
- "How does process A work?"
- "What are the main benefits of technique B?"
"""
    try:
        # Structured output forces the model to emit a ValidateQuery object.
        structured_llm_grader = llm.with_structured_output(ValidateQuery)
        user_prompt = [f"Question to analyze: {latest_question}"]
        prompt = [SystemMessage(validation_prompt)] + user_prompt
        validation_response = structured_llm_grader.invoke(prompt)
        end = time.time()
        logger.info(f"Time taken for question validation: {end - start} seconds")
        logger.info(f"Question validation result: {validation_response}")
        return {"query_validation": validation_response,
                "user_query": latest_question}
    except Exception as e:
        # Log with traceback; validation failure must not crash the graph.
        logger.exception(f"Error in question validation: {e}")
        # Fail closed: treat the question as invalid, again as a
        # ValidateQuery instance for type consistency with the success path.
        return {
            "query_validation": ValidateQuery(
                valid_question=False,
                response="Error in Question validation",
            ),
            "user_query": latest_question,
        }
def respond_or_call_retrieve_tool(state: State):
    """Generate tool call for retrieval or respond."""
    start = time.time()
    # Short-circuit when no message in the history carries any content.
    if not any(msg.content for msg in state["messages"]):
        return {"messages": []}
    # Let the model decide whether to answer directly or call retrieve_tool.
    llm_with_tools = llm.bind_tools([retrieve_tool])
    reply = llm_with_tools.invoke(state["messages"])
    end = time.time()
    logger.info(f"Time taken for query_or_respond_fn LLM invocation: {end - start} seconds")
    # MessagesState appends messages to state instead of overwriting
    return {"messages": [reply]}
@tool(response_format="content_and_artifact")
def retrieve_tool(query: str):
    """Retrieve information related to a query."""
    started = time.time()
    # Contextual-compression retriever narrows the candidate documents.
    docs = compression_retriever.invoke(input=query, k=30)
    # One "Source/Content" paragraph per document, blank-line separated.
    chunks = [
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in docs
    ]
    serialized = "\n\n".join(chunks)
    finished = time.time()
    logger.info(f"Time taken for vectordb retrieval: {finished - started} seconds")
    logger.info(f"Retrieved {len(docs)} documents for query: {query}")
    # content_and_artifact: text for the model, raw docs kept as artifact.
    return serialized, docs
def extract_retrieved_docs(state: State):
    """Extract retrieved documents from tool messages and update state."""
    logger.info("Entered extract_retrieved_docs")
    docs = []
    # Scan the history for artifacts produced by our retrieval tool.
    for msg in state["messages"]:
        if not (isinstance(msg, ToolMessage) and msg.name == "retrieve_tool"):
            continue
        try:
            # The artifact carries the raw Document objects.
            docs.extend(msg.artifact)
        except Exception as exc:
            logger.info(f"Error parsing tool result: {msg.content}")
            logger.info(f"Exception: {exc}")
    return {"retrieved_docs": docs}
def wait_for_query_validation_and_retrieved_docs(state: State):
    """Join node: open the response gate once validation AND retrieval ran.

    Sets:
        generate_response: True only when both "query_validation" and
            "retrieved_docs" are present in state.
        generate_sample_questions: True when the question was judged
            invalid, so the graph nudges the user with better questions.
    """
    logger.info("Entered wait_for_query_validation_and_retrieved_docs")
    logger.info(f" state.keys() : {state.keys()}")
    # Single quotes inside the f-strings keep this valid before Python 3.12
    # (same-quote nesting in f-strings requires 3.12+).
    logger.info(f"query_validation status : {'query_validation' in state}")
    logger.info(f"retrieved_docs status : {'retrieved_docs' in state}")
    logger.info(f"state['query_validation'] : {state.get('query_validation')}")
    if "query_validation" in state and "retrieved_docs" in state:
        logger.info("Both question validation and retrieved documents are present in state.")
        logger.info(f"state['query_validation'] : {state['query_validation']}")
        verdict = state["query_validation"]
        # Tolerate both the ValidateQuery model and the dict fallback that
        # validate_query emits on error — attribute access alone crashed here.
        is_valid = (verdict.get("valid_question")
                    if isinstance(verdict, dict)
                    else verdict.valid_question)
        response = {
            # Invalid question -> suggest sample questions instead.
            "generate_sample_questions": not is_valid,
            "generate_response": True,
        }
    else:
        # One of the parallel branches has not finished: keep the gate closed.
        response = {"generate_response": False}
    logger.info(f"wait_for_query_validation_and_retrieved_docs response : {response}")
    return response
def generate_questions(state: State):
    """Generate sample questions based on the retrieved documents."""
    start = time.time()
    # Combine the retrieved documents into a single response
    combined_response = "\n\n".join(
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in state["retrieved_docs"]
    )
    # NOTE(review): the prompt below reads `.valid_question` / `.response`
    # from state["query_validation"]; that assumes a ValidateQuery instance
    # and raises AttributeError when the dict fallback was stored — confirm.
    question_generation_prompt = f"""
# Question Nudging Prompt
You are a helpful assistant that generates better question suggestions for users based on available document content.
## Context
A user has asked a question that may be unclear, too broad, or not well-phrased. You have access to relevant documents retrieved from a vector database. Your task is to generate 3-5 sample questions that:
1. Are directly answerable using the provided document content
2. Guide the user toward more specific and actionable queries
3. Help clarify what the user might actually be looking for
4. Are phrased clearly and concisely
## Input Format
**User Question:** {state["user_query"]}
**Retrieved Documents:**
{combined_response}
**Question Validity Status:** {state["query_validation"].valid_question}
**Question validity justification:** {state["query_validation"].response}
## Instructions
1. Analyze the user's original question and the retrieved document content
2. Identify key topics, concepts, and specific information available in the documents
3. Generate 3-5 sample questions that:
   - Are more specific than the original question
   - Can be fully answered using only the provided document content
   - Cover different aspects or angles related to the user's intent
   - Use clear, direct language
   - Avoid assumptions not supported by the documents
4. During response generation, ad justification of question validity which is provided to you in the state["query_validation"].response
## Output Format
{state["query_validation"].response}
Based on the available information, here are some more specific questions you might want to ask:
1. [Specific question 1 based on document content]
2. [Specific question 2 based on document content]
3. [Specific question 3 based on document content]
4. [Specific question 4 based on document content] (if applicable)
5. [Specific question 5 based on document content] (if applicable)
## Important Guidelines
- Only suggest questions that can be answered using the provided documents
- Do not make up information not present in the documents
- Focus on practical, actionable questions
- Vary the scope and angle of suggested questions
- Keep questions concise and clear
- If the original question is valid but could be more specific, acknowledge this in your suggestions
"""
    response = llm.invoke(question_generation_prompt)
    end = time.time()
    # NOTE(review): the next two log lines are near-duplicates of each other.
    logger.info(f"Time taken for generate_fn : {end - start} seconds")
    logger.info(f"Time taken for generate_fn LLM invocation: {end - start} seconds")
    return {"messages": [response], "sample_questions":response.content}
def generate_answer_to_query(state: State):
    """Generate the final response based on retrieved documents.

    Builds a system prompt embedding the retrieved context, replays the
    human/system/non-tool-AI portion of the conversation, and asks the LLM
    to answer strictly from that context.

    Returns:
        dict: {"messages": [AIMessage], "final_response": str}.
    """
    logger.info("Entered generate_answer_to_query")
    start = time.time()
    logger.info(f"state.keys() : {state.keys()}")
    if not state["retrieved_docs"]:
        logger.info("No relevant documents found.")
        # Emit an AIMessage — a bare str would be coerced to a human-role
        # message by the messages reducer — and keep final_response in sync.
        fallback = AIMessage("No relevant documents found.")
        return {"messages": [fallback], "final_response": fallback.content}
    # Collect provenance (source file / page) for logging purposes only.
    docs = state["retrieved_docs"]
    sources_text = ""
    for doc in docs:
        page_label = doc.metadata.get('page_label')
        page = doc.metadata.get('page')
        source = doc.metadata.get('source')
        sources_text += f"Source: {source}, Page: {page}, Page Label: {page_label}\n"
    logger.info(f"sources_text {sources_text}")
    docs_content = "\n\n".join(doc.page_content for doc in docs)
    # System prompt pins the model to the retrieved context only.
    system_message_content = (
        "You are an assistant for question-answering tasks."
        "Use the following pieces of retrieved context to answer the question."
        "This is your only source of knowledge."
        "If you don't know the answer, say that you don't know and STOP - do not provide related information."
        "You are not allowed to make up answers."
        "You are not allowed to use any external knowledge."
        "You are not allowed to make assumptions."
        "If the query is not clearly and directly addressed in the knowledge source, simply state that you don't have enough information and DO NOT elaborate with tangentially related content."
        "Keep your answers strictly limited to information that directly answers the user's specific question."
        "When information is insufficient, acknowledge this limitation in one sentence without expanding into related topics."
        "Keep your answers accurate and concise to the source content."
        "\n\n"
        f"{docs_content}"
    )
    # Replay the conversation minus tool-call plumbing (tool calls and tool
    # results would confuse the answer-only prompt).
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages
    # Run
    start_llm = time.time()
    response = llm.invoke(prompt)
    end = time.time()
    logger.info(f"Time taken for generate_fn : {end - start} seconds")
    logger.info(f"Time taken for generate_fn LLM invocation: {end - start_llm} seconds")
    return {"messages": [response], "final_response": response.content}
def generate_flat_response(state: State):
    """Fallback node: emit a canned apology when no answer can be produced."""
    logger.info("Generating flat response")
    message = AIMessage("Unable to generate response. Please try again later.")
    # Store the text, not the message object: State.final_response is a str,
    # and the other generator nodes store `.content` as well.
    return {"messages": [message], "final_response": message.content}
def final_response_router(state: State):
    """Route to the appropriate response generation function.

    Returns the name of the next node: "generate_questions" for invalid
    questions, "generate_answer_to_query" for valid ones, or END when the
    join node has not opened the response gate yet.
    """
    logger.info("Entered final_response_router")
    logger.info(f"state.keys() : {state.keys()}")
    # Single quotes inside the f-strings keep this valid before Python 3.12;
    # the duplicated generate_response log line was dropped.
    logger.info(f"state.get('generate_response'): {state.get('generate_response')}")
    logger.info(f"state.get('generate_sample_questions'): {state.get('generate_sample_questions')}")
    if not state.get("generate_response"):
        logger.info("generate_response is False")
        return END
    logger.info("generate_response is True")
    # Invalid question -> nudge with sample questions; otherwise answer it.
    if state.get("generate_sample_questions"):
        return "generate_questions"
    return "generate_answer_to_query"
    # return "generate_flat_response"
# Node that executes the tool calls emitted by respond_or_call_retrieve_tool.
tools_node = ToolNode(tools=[retrieve_tool])
graph_builder = StateGraph(State)
graph_builder.add_node(node="validate_query", action=validate_query)
graph_builder.add_node(node="respond_or_call_retrieve_tool", action=respond_or_call_retrieve_tool)
graph_builder.add_node(node="tools", action=tools_node)
graph_builder.add_node(node="extract_retrieved_docs", action=extract_retrieved_docs)
graph_builder.add_node(node="wait_for_query_validation_and_retrieved_docs", action=wait_for_query_validation_and_retrieved_docs)
graph_builder.add_node(node="generate_questions", action=generate_questions)
graph_builder.add_node(node="generate_answer_to_query", action=generate_answer_to_query)
graph_builder.add_node(node="generate_flat_response", action=generate_flat_response)
# Two entry points fan out from START in parallel: question validation and
# the tool-calling retrieval branch both run, joining at the wait node.
graph_builder.set_entry_point(key="validate_query")
graph_builder.set_entry_point(key="respond_or_call_retrieve_tool")
graph_builder.add_conditional_edges(
    source = "respond_or_call_retrieve_tool",
    path = tools_condition,
    path_map= {END:END, "tools": "tools"}
)
graph_builder.add_edge(start_key="validate_query", end_key="wait_for_query_validation_and_retrieved_docs")
# NOTE(review): this unconditional edge to "tools" coexists with the
# tools_condition conditional edge above, so "tools" is entered even when
# the LLM emitted no tool call — likely one of the two should be removed;
# confirm intended behavior before changing.
graph_builder.add_edge(start_key="respond_or_call_retrieve_tool", end_key="tools")
graph_builder.add_edge(start_key="tools", end_key="extract_retrieved_docs")
graph_builder.add_edge(start_key="extract_retrieved_docs", end_key="wait_for_query_validation_and_retrieved_docs")
graph_builder.add_conditional_edges(
    "wait_for_query_validation_and_retrieved_docs",
    final_response_router,
    {
        "generate_questions": "generate_questions",
        "generate_answer_to_query": "generate_answer_to_query",
        # Mapped but currently unreachable: the router never returns this key.
        "generate_flat_response": "generate_flat_response",
        END: END
    }
)
graph_builder.add_edge(start_key="generate_questions", end_key=END)
graph_builder.add_edge(start_key="generate_answer_to_query", end_key=END)
graph_builder.add_edge(start_key="generate_flat_response", end_key=END)
graph = graph_builder.compile()
st.success("Initialization complete!")
return {"graph": graph}