"""LangGraph agent that answers questions from an AWS Knowledge Base.

The graph runs three nodes in sequence:

1. ``retrieve_documents`` - queries the knowledge base with the latest human message.
2. ``prepare_augmented_prompt`` - renders the retrieved documents into one text block.
3. ``model`` - calls the chat model with a system prompt that embeds that block.
"""

import logging
import os
from typing import Any

from langchain_aws import AmazonKnowledgeBasesRetriever
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langchain_core.runnables.base import RunnableSequence
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps

from core import get_model, settings

logger = logging.getLogger(__name__)

# Number of documents requested from the knowledge base per query.
DEFAULT_NUM_RESULTS = 3

# Written flush-left so source indentation does not leak into the prompt text.
_BASE_SYSTEM_PROMPT = """You are a helpful assistant that provides accurate information based on retrieved documents.

You will receive a query along with relevant documents retrieved from a knowledge base. Use these documents to inform your response.

Follow these guidelines:
1. Base your answer primarily on the retrieved documents
2. If the documents contain the answer, provide it clearly and concisely
3. If the documents are insufficient, state that you don't have enough information
4. Never make up facts or information not present in the documents
5. Always cite the source documents when referring to specific information
6. If the documents contradict each other, acknowledge this and explain the different perspectives

Format your response in a clear, conversational manner. Use markdown formatting when appropriate.
"""

_NO_DOCUMENTS_PROMPT = "\n\nNo relevant documents were found in the knowledge base for this query."


# Define the state
class AgentState(MessagesState, total=False):
    """State for Knowledge Base agent.

    ``total=False`` makes the keys declared here optional; ``messages`` comes from
    ``MessagesState``.
    """

    # Managed value imported from langgraph.managed.
    remaining_steps: RemainingSteps
    # One dict per retrieved document (id, source, title, content, relevance_score),
    # written by retrieve_documents.
    retrieved_documents: list[dict[str, Any]]
    # The retrieved documents rendered as a single text block by prepare_augmented_prompt;
    # "" when the current query retrieved nothing.
    kb_documents: str


def get_kb_retriever(num_results: int = DEFAULT_NUM_RESULTS) -> AmazonKnowledgeBasesRetriever:
    """Create a Knowledge Base retriever for the knowledge base named by ``AWS_KB_ID``.

    Args:
        num_results: Value passed as ``vectorSearchConfiguration.numberOfResults``
            (how many documents to request per query).

    Returns:
        A new ``AmazonKnowledgeBasesRetriever`` instance.

    Raises:
        ValueError: If the ``AWS_KB_ID`` environment variable is unset or empty.
    """
    kb_id = os.environ.get("AWS_KB_ID", "")
    if not kb_id:
        raise ValueError("AWS_KB_ID environment variable must be set")

    return AmazonKnowledgeBasesRetriever(
        knowledge_base_id=kb_id,
        retrieval_config={
            "vectorSearchConfiguration": {
                "numberOfResults": num_results,
            }
        },
    )


def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Wrap the model with a system prompt for the Knowledge Base agent.

    The returned runnable prepends a ``SystemMessage`` to ``state["messages"]``. The system
    message embeds ``state["kb_documents"]`` when it is non-empty; otherwise it says that no
    documents were found.
    """

    def create_system_message(state: AgentState) -> list[BaseMessage]:
        kb_documents = state.get("kb_documents")
        # Truthiness, not key presence: an empty string means "nothing retrieved this turn".
        if kb_documents:
            extra_prompt = (
                "\n\nI've retrieved the following documents that may be relevant to the query:"
                f"\n\n{kb_documents}\n\n"
                "Please use these documents to inform your response to the user's query. "
                "Only use information from these documents and clearly indicate when you are unsure."
            )
        else:
            extra_prompt = _NO_DOCUMENTS_PROMPT
        return [SystemMessage(content=_BASE_SYSTEM_PROMPT + extra_prompt)] + state["messages"]

    preprocessor = RunnableLambda(
        create_system_message,
        name="StateModifier",
    )
    return RunnableSequence(preprocessor, model)


def _message_text(content: Any) -> str:
    """Return the plain text of a message ``content`` value.

    ``content`` may be a string or a list of parts. From a list, only plain-string parts and
    dict parts with ``"type": "text"`` contribute; other parts (e.g. images) are ignored.
    """
    if isinstance(content, str):
        return content
    texts: list[str] = []
    for part in content or []:
        if isinstance(part, str):
            texts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            texts.append(str(part.get("text", "")))
    return "\n".join(texts)


def _summarize_document(index: int, doc: Any) -> dict[str, Any]:
    """Convert one retrieved document into the dict shape stored in ``retrieved_documents``.

    ``index`` is 1-based and only used to build fallback ``id``/``title`` values.
    """
    metadata = doc.metadata
    return {
        "id": metadata.get("id", f"doc-{index}"),
        "source": metadata.get("source", "Unknown"),
        "title": metadata.get("title", f"Document {index}"),
        "content": doc.page_content,
        "relevance_score": metadata.get("score", 0),
    }


async def retrieve_documents(state: AgentState, config: RunnableConfig) -> AgentState:
    """Retrieve relevant documents from the knowledge base.

    Uses the text of the most recent ``HumanMessage`` as the query. Retrieval is best effort:
    if there is no usable query or the retriever raises, ``retrieved_documents`` is set to an
    empty list (the error is logged) and the graph continues.
    """
    last_human = next(
        (msg for msg in reversed(state["messages"]) if isinstance(msg, HumanMessage)),
        None,
    )
    if last_human is None:
        return {"messages": [], "retrieved_documents": []}

    # Content may be a list of parts (e.g. multimodal input); the retriever needs a string.
    query = _message_text(last_human.content)
    if not query.strip():
        logger.info("Latest human message has no text content; skipping retrieval")
        return {"messages": [], "retrieved_documents": []}

    try:
        retriever = get_kb_retriever()
        retrieved_docs = await retriever.ainvoke(query)
    except Exception:
        # Deliberate best effort: answer with "no documents" rather than failing the turn,
        # but keep the traceback in the logs.
        logger.exception("Error retrieving documents")
        return {"retrieved_documents": [], "messages": []}

    document_summaries = [
        _summarize_document(index, doc) for index, doc in enumerate(retrieved_docs, 1)
    ]
    logger.info(
        "Retrieved %d documents for query: %.50s...", len(document_summaries), query
    )
    return {"retrieved_documents": document_summaries, "messages": []}


async def prepare_augmented_prompt(state: AgentState, config: RunnableConfig) -> AgentState:
    """Render ``retrieved_documents`` into the ``kb_documents`` text block.

    Always writes ``kb_documents``. With a checkpointer, state persists across turns. An
    untouched key would otherwise carry the previous query's documents into this one.
    """
    documents = state.get("retrieved_documents", [])
    if not documents:
        return {"kb_documents": "", "messages": []}

    formatted_docs = "\n\n".join(
        f"--- Document {number} ---\n"
        f"Source: {doc.get('source', 'Unknown')}\n"
        f"Title: {doc.get('title', 'Unknown')}\n\n"
        f"{doc.get('content', '')}"
        for number, doc in enumerate(documents, 1)
    )
    return {"kb_documents": formatted_docs, "messages": []}


async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Generate a response based on the retrieved documents.

    The model name is read from ``config["configurable"]["model"]``, falling back to
    ``settings.DEFAULT_MODEL`` when absent.
    """
    configurable = config.get("configurable") or {}
    model = get_model(configurable.get("model", settings.DEFAULT_MODEL))
    response = await wrap_model(model).ainvoke(state, config)
    return {"messages": [response]}


# Define the graph: retrieve -> format documents -> call model -> END.
agent = StateGraph(AgentState)

agent.add_node("retrieve_documents", retrieve_documents)
agent.add_node("prepare_augmented_prompt", prepare_augmented_prompt)
agent.add_node("model", acall_model)

agent.set_entry_point("retrieve_documents")

agent.add_edge("retrieve_documents", "prepare_augmented_prompt")
agent.add_edge("prepare_augmented_prompt", "model")
agent.add_edge("model", END)

kb_agent = agent.compile()