|
|
from dotenv import load_dotenv |
|
|
import os |
|
|
import streamlit as st |
|
|
from langchain_aws import BedrockEmbeddings |
|
|
from langchain_core.vectorstores import InMemoryVectorStore |
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain_core.documents import Document |
|
|
from typing_extensions import List, Dict |
|
|
|
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langgraph.graph import START, StateGraph, END |
|
|
|
|
|
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader |
|
|
from langgraph.graph import MessagesState |
|
|
from langchain_core.tools import tool |
|
|
from langchain_core.messages import SystemMessage |
|
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
from langchain_milvus import Milvus |
|
|
from utils import extract_text_from_content |
|
|
from logging_config import setup_logger |
|
|
from load_vector_db import init_vector_db |
|
|
from logging_config import setup_logger |
|
|
import time |
|
|
|
|
|
logger = setup_logger(__name__) |
|
|
|
|
|
|
|
|
# Behavioural rules given to the answering model. They are kept as separate
# sentences and joined with a single space so that consecutive sentences do
# not run together in the final system prompt.
_ANSWER_RULES = (
    "You are an assistant for question-answering tasks.",
    "Use the following pieces of retrieved context to answer the question.",
    "This is your only source of knowledge.",
    "If you don't know the answer, say that you don't know and STOP - do not provide related information.",
    "You are not allowed to make up answers.",
    "You are not allowed to use any external knowledge.",
    "You are not allowed to make assumptions.",
    "If the query is not clearly and directly addressed in the knowledge source, simply state that you don't have enough information and DO NOT elaborate with tangentially related content.",
    "Keep your answers strictly limited to information that directly answers the user's specific question.",
    "When information is insufficient, acknowledge this limitation in one sentence without expanding into related topics.",
    "If the query is single word or phrase, ask the user to provide a complete question.",
    "If the query is not clear, ask for clarification.",
    "If the query is not a complete question, ask the user to provide a complete question and provide some sample questions.",
    "If the query contains multiple questions, answer only the first question and ask the user to ask the next question.",
    "If the query contains complex or compound questions, break them down into simpler parts and answer each part separately.",
    "If the query is not related to the given knowledge source, mention that you can only answer from the knowledge base.",
    "Keep your answers accurate and concise to the source content.",
)


def _trailing_tool_messages(messages):
    """Return the unbroken run of tool messages ending *messages*, oldest first."""
    trailing = []
    for message in reversed(messages):
        if message.type != "tool":
            break
        trailing.append(message)
    trailing.reverse()
    return trailing


def init_graph():
    """Build the retrieval-augmented chat graph and return it.

    Creates the Bedrock chat model and embeddings, obtains the retriever from
    ``init_vector_db``, and wires three nodes into a LangGraph state graph:

    * ``query_or_respond`` - lets the model either answer directly or request
      a retrieval tool call;
    * ``tools`` - executes ``retrieve_tool``;
    * ``generate`` - answers using only the retrieved content.

    Progress is reported through a Streamlit spinner and success message.

    Returns:
        dict: ``{"graph": <compiled graph>}``.
    """
    with st.spinner("Initializing PDF chat application..."):
        llm = init_chat_model(
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
            model_provider="bedrock_converse",
            temperature=0,
        )
        embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")

        # Only the retriever is used here; the vector store handle is unused.
        _vector_store, compression_retriever = init_vector_db(embeddings)

        class State(MessagesState):
            # Documents gathered by the retrieval tool for the latest answer.
            context: List[Document]

        # NOTE: the docstring below doubles as the tool description handed to
        # the model, so it is intentionally left exactly as it was.
        @tool(response_format="content_and_artifact")
        def retrieve_tool(query: str):
            """Retrieve information related to a query."""
            start = time.perf_counter()
            retrieved_docs = compression_retriever.invoke(input=query, k=10)
            serialized = "\n\n".join(
                f"Source: {doc.metadata}\nContent: {doc.page_content}"
                for doc in retrieved_docs
            )
            elapsed = time.perf_counter() - start

            logger.info("Time taken for vectordb retrieval: %s seconds", elapsed)
            logger.info("retrieved_docs num: %s", len(retrieved_docs))
            logger.info("retrieved_docs : %s", retrieved_docs)
            return serialized, retrieved_docs

        # Binding the tool does not depend on the request, so do it once
        # instead of on every node invocation.
        llm_with_tools = llm.bind_tools([retrieve_tool])

        def query_or_respond_fn(state: State):
            """Let the model answer directly or emit a retrieval tool call.

            Returns an empty message update when no message in the state has
            any content; otherwise returns the model's response.
            """
            start = time.perf_counter()
            messages = state["messages"]
            logger.info("state['messages'] : %s", messages)

            if not any(message.content for message in messages):
                return {"messages": []}

            # The full history is sent (not only the non-empty messages) so
            # tool-call / tool-result pairs stay intact.
            response = llm_with_tools.invoke(messages)
            logger.info(
                "Time taken for query_or_respond_fn LLM invocation: %s seconds",
                time.perf_counter() - start,
            )
            return {"messages": [response]}

        def generate_fn(state: State):
            """Answer the user from the most recent retrieval results.

            Returns the model's answer under ``messages`` and the retrieved
            documents under ``context``.
            """
            start = time.perf_counter()
            tool_messages = _trailing_tool_messages(state["messages"])
            logger.info("tool_messages %s", tool_messages)

            # A tool message may carry no artifact (``None``); treat that as
            # "no documents" rather than failing on iteration.
            context = []
            for tool_message in tool_messages:
                context.extend(tool_message.artifact or [])

            # Log the provenance of every retrieved document, not just those
            # attached to the first tool message.
            sources_text = "".join(
                f"Source: {doc.metadata.get('source')}, "
                f"Page: {doc.metadata.get('page')}, "
                f"Page Label: {doc.metadata.get('page_label')}\n"
                for doc in context
            )
            logger.info("sources_text %s", sources_text)

            docs_content = "\n\n".join(
                tool_message.content for tool_message in tool_messages
            )
            system_message_content = (
                " ".join(_ANSWER_RULES) + "\n\n" + f"{docs_content}"
            )

            # Drop tool traffic: keep human/system turns and AI turns that
            # are plain answers (no tool calls).
            conversation_messages = [
                message
                for message in state["messages"]
                if message.type in ("human", "system")
                or (message.type == "ai" and not message.tool_calls)
            ]
            prompt = [SystemMessage(system_message_content)] + conversation_messages

            start_llm = time.perf_counter()
            response = llm.invoke(prompt)
            end = time.perf_counter()

            logger.info("Time taken for generate_fn : %s seconds", end - start)
            logger.info(
                "Time taken for generate_fn LLM invocation: %s seconds",
                end - start_llm,
            )
            return {"messages": [response], "context": context}

        tools_node = ToolNode([retrieve_tool])

        # Use ``State`` (not bare ``MessagesState``) so the ``context`` key
        # returned by ``generate`` is part of the graph's schema.
        graph_builder = StateGraph(State)
        graph_builder.add_node("query_or_respond", query_or_respond_fn)
        graph_builder.add_node("tools", tools_node)
        graph_builder.add_node("generate", generate_fn)
        graph_builder.set_entry_point("query_or_respond")
        graph_builder.add_conditional_edges(
            "query_or_respond",
            tools_condition,
            {END: END, "tools": "tools"},
        )
        graph_builder.add_edge("tools", "generate")
        graph_builder.add_edge("generate", END)
        graph = graph_builder.compile()

    st.success("Initialization complete!")
    return {"graph": graph}
|
|
|
|
|
|