File size: 6,624 Bytes
22dcdfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import logging
import os
from typing import Any
from langchain_aws import AmazonKnowledgeBasesRetriever
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langchain_core.runnables.base import RunnableSequence
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps
from core import get_model, settings
logger = logging.getLogger(__name__)
# Define the state
# Define the state
#
# Extends LangGraph's MessagesState (which carries the `messages` channel)
# with the retrieval artifacts produced by this agent's nodes.
# `total=False` makes every extra key optional, so nodes may return
# partial state updates.
class AgentState(MessagesState, total=False):
    """State for Knowledge Base agent."""

    # Managed by LangGraph: steps remaining before the run is cut off.
    remaining_steps: RemainingSteps
    # Per-document dicts produced by `retrieve_documents`
    # (id / source / title / content / relevance_score).
    retrieved_documents: list[dict[str, Any]]
    # Prompt-ready string rendering of the retrieved documents,
    # produced by `prepare_augmented_prompt` and consumed by `wrap_model`.
    kb_documents: str
# Create the retriever
def get_kb_retriever(num_results: int = 3) -> "AmazonKnowledgeBasesRetriever":
    """Create and return a Knowledge Base retriever instance.

    Args:
        num_results: Maximum number of documents to return per query.
            Defaults to 3, matching the previous hard-coded value.

    Returns:
        An ``AmazonKnowledgeBasesRetriever`` bound to the Knowledge Base
        identified by the ``AWS_KB_ID`` environment variable.

    Raises:
        ValueError: If ``AWS_KB_ID`` is unset or empty.
    """
    # The Knowledge Base ID must be supplied via the environment.
    kb_id = os.environ.get("AWS_KB_ID", "")
    if not kb_id:
        raise ValueError("AWS_KB_ID environment variable must be set")
    # Create the retriever with the specified Knowledge Base ID.
    return AmazonKnowledgeBasesRetriever(
        knowledge_base_id=kb_id,
        retrieval_config={
            "vectorSearchConfiguration": {
                "numberOfResults": num_results,
            }
        },
    )
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Wrap the model with a system prompt for the Knowledge Base agent."""

    def _inject_system_prompt(state):
        # Grounding instructions sent ahead of the conversation history.
        instructions = """You are a helpful assistant that provides accurate information based on retrieved documents.
You will receive a query along with relevant documents retrieved from a knowledge base. Use these documents to inform your response.
Follow these guidelines:
1. Base your answer primarily on the retrieved documents
2. If the documents contain the answer, provide it clearly and concisely
3. If the documents are insufficient, state that you don't have enough information
4. Never make up facts or information not present in the documents
5. Always cite the source documents when referring to specific information
6. If the documents contradict each other, acknowledge this and explain the different perspectives
Format your response in a clear, conversational manner. Use markdown formatting when appropriate.
"""
        # Tail of the system prompt depends on whether retrieval produced anything.
        if "kb_documents" in state:
            suffix = f"\n\nI've retrieved the following documents that may be relevant to the query:\n\n{state['kb_documents']}\n\nPlease use these documents to inform your response to the user's query. Only use information from these documents and clearly indicate when you are unsure."
        else:
            suffix = (
                "\n\nNo relevant documents were found in the knowledge base for this query."
            )
        return [SystemMessage(content=instructions + suffix)] + state["messages"]

    # Prepend the system message to the conversation, then hand off to the model.
    return RunnableSequence(
        RunnableLambda(_inject_system_prompt, name="StateModifier"),
        model,
    )
async def retrieve_documents(state: AgentState, config: RunnableConfig) -> AgentState:
"""Retrieve relevant documents from the knowledge base."""
# Get the last human message
human_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
if not human_messages:
# Include messages from original state
return {"messages": [], "retrieved_documents": []}
# Use the last human message as the query
query = human_messages[-1].content
try:
# Initialize the retriever
retriever = get_kb_retriever()
# Retrieve documents
retrieved_docs = await retriever.ainvoke(query)
# Create document summaries for the state
document_summaries = []
for i, doc in enumerate(retrieved_docs, 1):
summary = {
"id": doc.metadata.get("id", f"doc-{i}"),
"source": doc.metadata.get("source", "Unknown"),
"title": doc.metadata.get("title", f"Document {i}"),
"content": doc.page_content,
"relevance_score": doc.metadata.get("score", 0),
}
document_summaries.append(summary)
logger.info(f"Retrieved {len(document_summaries)} documents for query: {query[:50]}...")
return {"retrieved_documents": document_summaries, "messages": []}
except Exception as e:
logger.error(f"Error retrieving documents: {str(e)}")
return {"retrieved_documents": [], "messages": []}
async def prepare_augmented_prompt(state: AgentState, config: RunnableConfig) -> AgentState:
"""Prepare a prompt augmented with retrieved document content."""
# Get retrieved documents
documents = state.get("retrieved_documents", [])
if not documents:
return {"messages": []}
# Format retrieved documents for the model
formatted_docs = "\n\n".join(
[
f"--- Document {i + 1} ---\n"
f"Source: {doc.get('source', 'Unknown')}\n"
f"Title: {doc.get('title', 'Unknown')}\n\n"
f"{doc.get('content', '')}"
for i, doc in enumerate(documents)
]
)
# Store formatted documents in the state
return {"kb_documents": formatted_docs, "messages": []}
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Generate a response based on the retrieved documents."""
    # Resolve the model from the run config, falling back to the default.
    model_name = config["configurable"].get("model", settings.DEFAULT_MODEL)
    runnable = wrap_model(get_model(model_name))
    ai_message = await runnable.ainvoke(state, config)
    return {"messages": [ai_message]}
# Define the graph: a linear RAG pipeline —
#   retrieve_documents -> prepare_augmented_prompt -> model -> END
agent = StateGraph(AgentState)
# Add nodes (one per pipeline stage)
agent.add_node("retrieve_documents", retrieve_documents)
agent.add_node("prepare_augmented_prompt", prepare_augmented_prompt)
agent.add_node("model", acall_model)
# Set entry point: every run starts with retrieval
agent.set_entry_point("retrieve_documents")
# Add edges to define the flow (unconditional — no branching)
agent.add_edge("retrieve_documents", "prepare_augmented_prompt")
agent.add_edge("prepare_augmented_prompt", "model")
agent.add_edge("model", END)
# Compile the agent into a runnable graph
kb_agent = agent.compile()
|