|
|
from dotenv import load_dotenv |
|
|
import os |
|
|
import streamlit as st |
|
|
from langchain_aws import BedrockEmbeddings |
|
|
from langchain_core.vectorstores import InMemoryVectorStore |
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain_core.documents import Document |
|
|
from typing import List, Dict, Any |
|
|
|
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langgraph.graph import START, StateGraph, END |
|
|
|
|
|
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader |
|
|
from langgraph.graph import MessagesState |
|
|
from langchain_core.tools import tool |
|
|
from langchain_core.messages import SystemMessage, AIMessage, ToolMessage |
|
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
from langchain_milvus import Milvus |
|
|
from utils import extract_text_from_content |
|
|
from logging_config import setup_logger |
|
|
from load_vector_db import init_vector_db |
|
|
from logging_config import setup_logger |
|
|
import time |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
# Module-level logger wired through the project's central logging setup.
logger = setup_logger(__name__)
|
|
|
|
|
|
|
|
def init_graph():
    """Initialize the app components and return them.

    Builds the embeddings, vector store / retriever, LLM, and the LangGraph
    pipeline (defined further below in this function), showing a Streamlit
    spinner while initialization runs.
    """
    with st.spinner("Initializing PDF chat application..."):
        # Titan embeddings back the vector store used for retrieval.
        embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")

        # init_vector_db returns both the raw store and a compression
        # retriever wrapped around it; only the retriever is used below.
        vector_store, compression_retriever = init_vector_db(embeddings)

        # Deterministic (temperature=0) Claude 3.5 Sonnet via Bedrock Converse.
        llm = init_chat_model(
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
            model_provider="bedrock_converse",
            temperature=0
        )
|
|
class State(MessagesState):
    # Graph state shared by every node; extends MessagesState's `messages`.
    # NOTE(review): MessagesState is TypedDict-based, so the "defaults"
    # below are not applied at runtime -- nodes must populate these keys
    # before readers access them. TODO confirm against langgraph version.
    user_query: str = ""                       # latest human question text
    query_validation: Dict[str, Any] = {}      # NOTE(review): success path stores a ValidateQuery instance, not a dict -- annotation looks stale; confirm
    retrieved_docs: List[Document] = []        # documents produced by retrieve_tool
    generate_response: bool = False            # join-gate: both parallel branches finished
    generate_sample_questions: bool = False    # True when the question failed validation
    sample_questions : str = ""                # LLM-suggested follow-up questions
    final_response : str = ""                  # answer text surfaced to the UI
|
|
|
|
|
class ValidateQuery(BaseModel):
    """Binary score for question validation.

    Used as the structured-output schema for the grading LLM; the field
    descriptions below are sent to the model as part of that schema, so
    they are runtime behavior -- do not edit them casually.
    """

    # Whether the user's question is clear and answerable.
    valid_question: bool = Field(
        description="Question is clear and answerable, 'true' or 'false'"
    )
    # Human-readable justification surfaced later in prompts and the UI.
    response: str = Field(
        description="Explanation of the question's validity and suggestions for improvement."
    )
|
|
|
|
|
|
|
|
def validate_query(state: State):
    """Validate if the latest human question is clear and answerable.

    Returns a partial state update:
      - query_validation: a ValidateQuery instance. Bug fix: the empty-input
        and error fallbacks previously returned plain dicts (one with the
        response wrapped in a list), which crashed downstream consumers that
        use attribute access (state["query_validation"].valid_question);
        all paths now return ValidateQuery.
      - user_query: the question that was validated (when one exists).
    """
    start = time.time()

    # Only human turns are candidates for validation.
    human_messages = [msg for msg in state["messages"] if msg.type == "human"]

    if not human_messages:
        return {
            "query_validation": ValidateQuery(
                valid_question=False,
                response="No question found to validate.",
            )
        }

    latest_question = human_messages[-1].content

    # System prompt for the grading LLM. Runtime string -- kept verbatim.
    validation_prompt = f"""
    You are a question validation assistant. Analyze the following question and determine if it is:
    1. Clear and specific (not vague or ambiguous)
    2. Focused on a single topic or closely related topics (not trying to address too many different things)
    3. Answerable without making assumptions
    4. Concise and well-structured

    # "If the query is single word or phrase, ask the user to provide a complete question."
    # "If the query is not clear, ask for clarification."
    # "If the query is not a complete question, ask the user to provide a complete question and provide some sample questions."
    # "If the query contains multiple questions, answer only the first question and ask the user to ask the next question."
    # "If the query contains complex or compound questions, break them down into simpler parts and answer each part separately."
    # "If the query is not related to the given knowledge source, mention that you can only answer from the knowledge base."

    Respond with a JSON object containing:
    - "valid_question": boolean (true if question is valid, false if not)
    - "response": Explain about the validity of the question and suggest improvements. But dont give any example questions.

    Examples of invalid questions:
    - "Tell me everything about X" (too broad)
    - "What about A, B, C, and also how does D relate to E?" (too many topics)
    - "Why is it better?" (vague, missing context)
    - Questions that would require assumptions about unstated context

    Examples of valid questions:
    - "What is the definition of X?"
    - "How does process A work?"
    - "What are the main benefits of technique B?"
    """

    try:
        # Bind the ValidateQuery schema so the model must return it.
        structured_llm_grader = llm.with_structured_output(ValidateQuery)

        user_prompt = [f"Question to analyze: {latest_question}"]
        prompt = [SystemMessage(validation_prompt)] + user_prompt
        validation_response = structured_llm_grader.invoke(prompt)

        end = time.time()
        logger.info("Time taken for question validation: %s seconds", end - start)
        logger.info("Question validation result: %s", validation_response)

        return {"query_validation": validation_response,
                "user_query": latest_question}

    except Exception as e:
        # logger.exception keeps the traceback; a bare info line hid the cause.
        logger.exception("Error in question validation: %s", e)

        return {"query_validation": ValidateQuery(
                    valid_question=False,
                    response="Error in Question validation",
                ),
                "user_query": latest_question}
|
|
|
|
|
|
|
|
def respond_or_call_retrieve_tool(state: State):
    """Generate tool call for retrieval or respond."""
    start = time.time()

    # Bail out early when every message has empty content.
    # NOTE(review): `non_empty` is only used for this emptiness guard; the
    # LLM still receives the unfiltered history below. Presumably deliberate
    # (filtering could drop tool-call messages and break pairing) -- confirm.
    non_empty = [m for m in state["messages"] if m.content]
    if not non_empty:
        return {"messages": []}

    # Let the model decide between answering directly and calling the tool.
    llm_with_tools = llm.bind_tools([retrieve_tool])
    response = llm_with_tools.invoke(state["messages"])

    end = time.time()
    logger.info(f"Time taken for query_or_respond_fn LLM invocation: {end - start} seconds")

    return {"messages": [response]}
|
|
|
|
|
|
|
|
@tool(response_format="content_and_artifact")
def retrieve_tool(query: str):
    """Retrieve information related to a query."""
    start = time.time()

    # Pull candidate documents through the compression retriever.
    docs = compression_retriever.invoke(input=query, k=30)

    # Flatten each document into a "Source / Content" section for the LLM;
    # the raw documents travel alongside as the tool artifact.
    sections = [
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in docs
    ]
    serialized = "\n\n".join(sections)

    end = time.time()
    logger.info(f"Time taken for vectordb retrieval: {end - start} seconds")
    logger.info(f"Retrieved {len(docs)} documents for query: {query}")

    return serialized, docs
|
|
|
|
|
|
|
|
def extract_retrieved_docs(state: State):
    """Extract retrieved documents from tool messages and update state."""
    docs = []
    logger.info(f"Entered extract_retrieved_docs")

    # Collect artifacts from every retrieve_tool ToolMessage in the history;
    # a malformed artifact is logged and skipped rather than aborting.
    for message in state["messages"]:
        if not (isinstance(message, ToolMessage) and message.name == "retrieve_tool"):
            continue
        try:
            docs.extend(message.artifact)
        except Exception as e:
            logger.info(f"Error parsing tool result: {message.content}")
            logger.info(f"Exception: {e}")

    return {"retrieved_docs": docs}
|
|
|
|
|
|
|
|
def wait_for_query_validation_and_retrieved_docs(state: State):
    """Join node: proceed only once both parallel branches wrote their state.

    The graph runs validation and retrieval in parallel; this node checks
    that both "query_validation" and "retrieved_docs" are present. If so it
    sets generate_response=True and flags generate_sample_questions when the
    question failed validation; otherwise it signals the router to end.

    Bug fix: the original log lines nested double quotes inside
    double-quoted f-strings (e.g. f"... {"query_validation" in state}"),
    which is a SyntaxError on Python < 3.12. Rewritten with lazy %s logging.
    """
    logger.info("Entered wait_for_query_validation_and_retrieved_docs")
    logger.info(" state.keys() : %s", state.keys())

    logger.info("query_validation status : %s", "query_validation" in state)
    logger.info("retrieved_docs status : %s", "retrieved_docs" in state)
    logger.info("state['query_validation'] : %s", state.get("query_validation"))

    if "query_validation" in state and "retrieved_docs" in state:
        logger.info("Both question validation and retrieved documents are present in state.")
        logger.info("state['query_validation'] : %s", state["query_validation"])

        # Valid question -> answer it; invalid -> nudge with sample questions.
        # Either way a response is generated.
        response = {
            "generate_sample_questions": not state["query_validation"].valid_question,
            "generate_response": True,
        }
    else:
        # One branch has not finished yet; the router will end this pass.
        response = {"generate_response": False}

    logger.info("wait_for_query_validation_and_retrieved_docs response : %s", response)

    return response
|
|
|
|
|
def generate_questions(state: State):
    """Generate sample questions based on the retrieved documents.

    Runs when the user's question failed validation: asks the LLM to
    suggest 3-5 better questions answerable from the retrieved context,
    returning them both as a message and in `sample_questions`.

    Assumes state["query_validation"] exposes attribute access
    (.valid_question / .response, i.e. a ValidateQuery instance) -- confirm
    every writer of that key provides one.
    """
    start = time.time()

    # Flatten every retrieved document into one prompt-friendly string.
    combined_response = "\n\n".join(
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in state["retrieved_docs"]
    )

    # Runtime prompt text -- interpolates the user query, retrieved context,
    # and the validation verdict/justification.
    question_generation_prompt = f"""
    # Question Nudging Prompt

    You are a helpful assistant that generates better question suggestions for users based on available document content.

    ## Context
    A user has asked a question that may be unclear, too broad, or not well-phrased. You have access to relevant documents retrieved from a vector database. Your task is to generate 3-5 sample questions that:

    1. Are directly answerable using the provided document content
    2. Guide the user toward more specific and actionable queries
    3. Help clarify what the user might actually be looking for
    4. Are phrased clearly and concisely

    ## Input Format
    **User Question:** {state["user_query"]}

    **Retrieved Documents:**
    {combined_response}

    **Question Validity Status:** {state["query_validation"].valid_question}

    **Question validity justification:** {state["query_validation"].response}

    ## Instructions
    1. Analyze the user's original question and the retrieved document content
    2. Identify key topics, concepts, and specific information available in the documents
    3. Generate 3-5 sample questions that:
       - Are more specific than the original question
       - Can be fully answered using only the provided document content
       - Cover different aspects or angles related to the user's intent
       - Use clear, direct language
       - Avoid assumptions not supported by the documents
    4. During response generation, ad justification of question validity which is provided to you in the state["query_validation"].response

    ## Output Format
    {state["query_validation"].response}

    Based on the available information, here are some more specific questions you might want to ask:

    1. [Specific question 1 based on document content]
    2. [Specific question 2 based on document content]
    3. [Specific question 3 based on document content]
    4. [Specific question 4 based on document content] (if applicable)
    5. [Specific question 5 based on document content] (if applicable)

    ## Important Guidelines
    - Only suggest questions that can be answered using the provided documents
    - Do not make up information not present in the documents
    - Focus on practical, actionable questions
    - Vary the scope and angle of suggested questions
    - Keep questions concise and clear
    - If the original question is valid but could be more specific, acknowledge this in your suggestions
    """

    response = llm.invoke(question_generation_prompt)

    end = time.time()
    logger.info(f"Time taken for generate_fn : {end - start} seconds")
    # NOTE(review): same start/end pair logged twice; this second line's
    # "LLM invocation" label does not measure the LLM call alone.
    logger.info(f"Time taken for generate_fn LLM invocation: {end - start} seconds")

    return {"messages": [response], "sample_questions":response.content}
|
|
|
|
|
|
|
|
|
|
|
def generate_answer_to_query(state: State):
    """Generate the final response based on retrieved documents.

    Builds a grounding system prompt from the retrieved documents plus the
    non-tool conversation history, invokes the LLM, and returns the answer
    as both a message and `final_response`.

    Fixes vs. the original: the docstring was preceded by a statement (so it
    was a no-op string, not a docstring), a second stray "Generate answer."
    string sat mid-function, `start` was assigned twice, and the
    no-documents path returned a bare string (which add_messages coerces
    into a *human* message) without setting final_response.
    """
    logger.info("Entered generate_answer_to_query")
    start = time.time()

    logger.info("state.keys() : %s", state.keys())

    if not state["retrieved_docs"]:
        logger.info("No relevant documents found.")
        # Emit an AIMessage and populate final_response so the UI sees the
        # fallback text instead of a mis-typed message.
        fallback = "No relevant documents found."
        return {"messages": [AIMessage(fallback)], "final_response": fallback}

    docs = state["retrieved_docs"]

    # Provenance line per document -- for the log only, not the prompt.
    sources_text = ""
    for doc in docs:
        page_label = doc.metadata.get('page_label')
        page = doc.metadata.get('page')
        source = doc.metadata.get('source')
        sources_text += f"Source: {source}, Page: {page}, Page Label: {page_label}\n"

    logger.info("sources_text %s", sources_text)

    docs_content = "\n\n".join(doc.page_content for doc in docs)
    # Strict grounding instructions; the retrieved text is appended inline.
    system_message_content = (
        "You are an assistant for question-answering tasks."
        "Use the following pieces of retrieved context to answer the question."
        "This is your only source of knowledge."
        "If you don't know the answer, say that you don't know and STOP - do not provide related information."
        "You are not allowed to make up answers."
        "You are not allowed to use any external knowledge."
        "You are not allowed to make assumptions."
        "If the query is not clearly and directly addressed in the knowledge source, simply state that you don't have enough information and DO NOT elaborate with tangentially related content."
        "Keep your answers strictly limited to information that directly answers the user's specific question."
        "When information is insufficient, acknowledge this limitation in one sentence without expanding into related topics."
        "Keep your answers accurate and concise to the source content."
        "\n\n"
        f"{docs_content}"
    )
    # Keep human/system turns and tool-free AI turns; drop tool plumbing.
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages

    start_llm = time.time()
    response = llm.invoke(prompt)

    end = time.time()
    logger.info("Time taken for generate_fn : %s seconds", end - start)
    logger.info("Time taken for generate_fn LLM invocation: %s seconds", end - start_llm)

    return {"messages": [response], "final_response": response.content}
|
|
|
|
|
|
|
|
def generate_flat_response(state: State):
    """Fallback node: emit a canned failure message.

    NOTE(review): with the current router this node appears unreachable
    (final_response_router never returns "generate_flat_response") --
    confirm the wiring intent.
    """
    # Typo fix in the log text ("Generateing").
    logger.info("Generating flat response")

    message = AIMessage("Unable to generate response. Please try again later.")

    # Bug fix: store the text, not the AIMessage object, in final_response,
    # matching generate_answer_to_query which stores response.content.
    return {"messages": [message], "final_response": message.content}
|
|
|
|
|
|
|
|
|
|
|
def final_response_router(state: State):
    """Route to the appropriate response generation function.

    Returns "generate_questions" when the question was invalid,
    "generate_answer_to_query" when it was valid, and END when the join
    node decided no response should be generated yet.

    Bug fix: the original log lines nested double quotes inside
    double-quoted f-strings (SyntaxError on Python < 3.12) and logged
    generate_response twice; rewritten with lazy %s logging.
    """
    logger.info("Entered final_response_router")
    logger.info("state.keys() : %s", state.keys())
    logger.info("state.get('generate_response'): %s", state.get("generate_response"))
    logger.info("state.get('generate_sample_questions'): %s", state.get("generate_sample_questions"))

    if state.get("generate_response"):
        logger.info("generate_response is True")
        if state.get("generate_sample_questions"):
            return "generate_questions"
        return "generate_answer_to_query"

    logger.info("generate_response is False")
    return END
|
|
|
|
|
|
|
|
|
|
|
# --- Graph assembly -------------------------------------------------------
# Executes the retrieve_tool calls emitted by respond_or_call_retrieve_tool.
tools_node = ToolNode(tools=[retrieve_tool])

graph_builder = StateGraph(State)

graph_builder.add_node(node="validate_query", action=validate_query)
graph_builder.add_node(node="respond_or_call_retrieve_tool", action=respond_or_call_retrieve_tool)
graph_builder.add_node(node="tools", action=tools_node)
graph_builder.add_node(node="extract_retrieved_docs", action=extract_retrieved_docs)
graph_builder.add_node(node="wait_for_query_validation_and_retrieved_docs", action=wait_for_query_validation_and_retrieved_docs)
graph_builder.add_node(node="generate_questions", action=generate_questions)
graph_builder.add_node(node="generate_answer_to_query", action=generate_answer_to_query)
graph_builder.add_node(node="generate_flat_response", action=generate_flat_response)

# Calling set_entry_point twice wires START to both nodes, so validation and
# retrieval run as parallel branches that re-join at the wait node.
graph_builder.set_entry_point(key="validate_query")
graph_builder.set_entry_point(key="respond_or_call_retrieve_tool")

# NOTE(review): this conditional edge (go to "tools" only when the LLM made
# a tool call) coexists with the unconditional
# respond_or_call_retrieve_tool -> tools edge added below, which forces the
# tool node on every pass -- confirm which routing is intended.
graph_builder.add_conditional_edges(
    source = "respond_or_call_retrieve_tool",
    path = tools_condition,
    path_map= {END:END, "tools": "tools"}
)

graph_builder.add_edge(start_key="validate_query", end_key="wait_for_query_validation_and_retrieved_docs")
graph_builder.add_edge(start_key="respond_or_call_retrieve_tool", end_key="tools")
graph_builder.add_edge(start_key="tools", end_key="extract_retrieved_docs")
graph_builder.add_edge(start_key="extract_retrieved_docs", end_key="wait_for_query_validation_and_retrieved_docs")

# Route to question-nudging, answering, or END once both branches report in.
# NOTE(review): "generate_flat_response" is mapped here but
# final_response_router never returns that value, so the node looks
# unreachable -- confirm.
graph_builder.add_conditional_edges(
    "wait_for_query_validation_and_retrieved_docs",
    final_response_router,
    {
        "generate_questions": "generate_questions",
        "generate_answer_to_query": "generate_answer_to_query",
        "generate_flat_response": "generate_flat_response",
        END: END
    }
)

graph_builder.add_edge(start_key="generate_questions", end_key=END)
graph_builder.add_edge(start_key="generate_answer_to_query", end_key=END)
graph_builder.add_edge(start_key="generate_flat_response", end_key=END)

graph = graph_builder.compile()

st.success("Initialization complete!")
return {"graph": graph}
|
|
|
|
|
|