"""
LangGraph workflow implementation for the RAG Q&A Agent.

Defines the agent graph with plan, retrieve, answer, and reflect nodes.
"""

from typing import TypedDict, Annotated, Dict, Any, List

from langgraph.graph import StateGraph, END

from rag_pipeline import RAGPipeline
from llm_utils import LLMHandler
from reflection import ReflectionEvaluator

import operator


class AgentState(TypedDict):
    """State passed between nodes in the agent workflow."""
    query: str                              # original user question
    plan: str                               # raw planning-LLM response
    needs_retrieval: bool                   # whether the retrieve node should run
    retrieved_context: str                  # concatenated context passed to the answer LLM
    retrieved_chunks: List[Dict[str, Any]]  # raw chunks returned by the vector store
    answer: str                             # most recently generated answer
    reflection: Dict[str, Any]              # evaluator result (plus the routing decision)
    final_response: str                     # answer accepted as the final output
    iteration: int                          # number of regeneration attempts so far


class RAGAgent:
    """LangGraph-based RAG Q&A Agent with reflection."""

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        llm_handler: LLMHandler,
        reflection_evaluator: ReflectionEvaluator,
        max_iterations: int = 2
    ):
        """
        Initialize the RAG agent.

        Args:
            rag_pipeline: RAG pipeline for retrieval
            llm_handler: LLM handler for generation
            reflection_evaluator: Reflection evaluator
            max_iterations: Maximum reflection iterations
        """
        self.rag_pipeline = rag_pipeline
        self.llm_handler = llm_handler
        self.reflection_evaluator = reflection_evaluator
        self.max_iterations = max_iterations

        # Build the graph
        self.graph = self._build_graph()

        print("āœ“ RAG Agent workflow initialized")

    def _build_graph(self):
        """Build and compile the LangGraph workflow.

        Topology: plan -> (retrieve | answer) -> answer -> reflect ->
        (END | answer). All state mutation happens in nodes; the conditional
        edge functions are pure reads, because LangGraph does not persist
        writes made inside routing functions.
        """
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("plan", self.plan_node)
        workflow.add_node("retrieve", self.retrieve_node)
        workflow.add_node("answer", self.answer_node)
        workflow.add_node("reflect", self.reflect_node)

        # Define edges
        workflow.set_entry_point("plan")

        # Plan -> Retrieve or Answer
        workflow.add_conditional_edges(
            "plan",
            self.should_retrieve,
            {
                True: "retrieve",
                False: "answer"
            }
        )

        # Retrieve -> Answer
        workflow.add_edge("retrieve", "answer")

        # Answer -> Reflect
        workflow.add_edge("answer", "reflect")

        # Reflect -> End or Answer (for regeneration)
        workflow.add_conditional_edges(
            "reflect",
            self.should_regenerate,
            {
                "accept": END,
                "regenerate": "answer",
                "end": END
            }
        )

        return workflow.compile()

    @staticmethod
    def _parse_needs_retrieval(plan_response: str) -> bool:
        """
        Parse the NEEDS_RETRIEVAL flag out of the planning LLM's response.

        Args:
            plan_response: Raw text returned by the planning prompt

        Returns:
            True if the plan says retrieval is needed; defaults to True
            (retrieving unnecessarily is safer than skipping retrieval)
            when the marker is absent.
        """
        upper = plan_response.upper()
        if "NEEDS_RETRIEVAL:" not in upper:
            return True
        # Only look at the remainder of the line that carries the marker.
        flag_line = upper.split("NEEDS_RETRIEVAL:", 1)[1].split("\n", 1)[0]
        return "YES" in flag_line

    def plan_node(self, state: AgentState) -> AgentState:
        """
        Planning node: Analyze query and decide if retrieval is needed.

        Args:
            state: Current agent state

        Returns:
            Updated state with plan, needs_retrieval flag, and iteration reset
        """
        print("\n" + "="*60)
        print("šŸ“‹ NODE: PLAN")
        print("="*60 + "\n")

        query = state["query"]
        print(f"Query: {query}\n")

        # Use LLM to analyze query and create a plan
        planning_prompt = f"""Analyze the following user query and determine if it requires retrieving information from a knowledge base.

User Query: "{query}"

Consider:
1. Is this a factual question that would benefit from specific documentation or knowledge?
2. Is this a general question that can be answered without specific context?
3. Does this query ask about specific concepts, technologies, or topics?

Respond in the following format:
NEEDS_RETRIEVAL: [YES/NO]
REASONING: [Brief explanation]
PLAN: [How you will approach answering this query]"""

        system_message = "You are a query planning agent. Analyze queries and determine the best approach to answer them."

        plan_response = self.llm_handler.generate(
            planning_prompt,
            system_message
        )

        needs_retrieval = self._parse_needs_retrieval(plan_response)

        print(f"Plan Response:\n{plan_response}\n")
        print(f"Needs Retrieval: {needs_retrieval}")

        state["plan"] = plan_response
        state["needs_retrieval"] = needs_retrieval
        state["iteration"] = 0

        print("\n" + "="*60 + "\n")
        return state

    def should_retrieve(self, state: AgentState) -> bool:
        """Conditional edge: Determine if retrieval is needed (pure read)."""
        return state["needs_retrieval"]

    def retrieve_node(self, state: AgentState) -> AgentState:
        """
        Retrieval node: Retrieve relevant context from vector store.

        Args:
            state: Current agent state

        Returns:
            Updated state with retrieved context and chunks
        """
        print("\n" + "="*60)
        print("šŸ” NODE: RETRIEVE")
        print("="*60 + "\n")

        query = state["query"]

        # Retrieve context
        context, chunks = self.rag_pipeline.retrieve_context(query, top_k=3)

        print(f"Retrieved {len(chunks)} relevant chunks\n")

        # Display retrieved content preview
        for i, chunk in enumerate(chunks):
            preview = chunk['content'][:150] + "..." if len(chunk['content']) > 150 else chunk['content']
            print(f"Chunk {i+1} Preview: {preview}\n")

        state["retrieved_context"] = context
        state["retrieved_chunks"] = chunks

        print("="*60 + "\n")
        return state

    def answer_node(self, state: AgentState) -> AgentState:
        """
        Answer generation node: Generate answer using LLM.

        On regeneration passes (iteration > 0) the reflection feedback is
        folded into the prompt so the LLM can improve on the rejected answer.

        Args:
            state: Current agent state

        Returns:
            Updated state with generated answer
        """
        print("\n" + "="*60)
        print("šŸ’¬ NODE: ANSWER")
        print("="*60 + "\n")

        query = state["query"]
        iteration = state.get("iteration", 0)

        if iteration > 0:
            print(f"[Regeneration attempt {iteration}]\n")

        # Check if we have retrieved context
        if state.get("retrieved_context"):
            # Generate answer with context
            context = state["retrieved_context"]

            # Check if this is a regeneration with feedback
            if iteration > 0 and state.get("reflection"):
                # .get() guards against an evaluator result without "reasoning".
                feedback = state["reflection"].get("reasoning", "")
                answer = self._generate_answer_with_feedback(query, context, feedback)
            else:
                answer = self.llm_handler.generate_with_context(
                    query,
                    context,
                    system_message="You are a helpful AI assistant. Answer questions accurately based on the provided context."
                )
        else:
            # Generate answer without context
            answer = self.llm_handler.generate(
                query,
                system_message="You are a helpful AI assistant. Answer questions concisely and accurately."
            )

        print(f"Generated Answer:\n{answer}\n")

        state["answer"] = answer

        print("="*60 + "\n")
        return state

    def _generate_answer_with_feedback(
        self,
        query: str,
        context: str,
        feedback: str
    ) -> str:
        """
        Generate answer incorporating feedback from reflection.

        Args:
            query: User query
            context: Retrieved context
            feedback: Feedback from reflection

        Returns:
            Regenerated answer
        """
        prompt = f"""The previous answer was not satisfactory. Here's the feedback:

{feedback}

Now, please generate a better answer to the following question using the context provided.

Context:
{context}

Question: {query}

Provide a comprehensive, accurate, and relevant answer that addresses the feedback."""

        system_message = "You are a helpful AI assistant. Learn from feedback and provide improved answers."

        return self.llm_handler.generate(prompt, system_message)

    def reflect_node(self, state: AgentState) -> AgentState:
        """
        Reflection node: Evaluate answer quality and record the routing decision.

        All state updates (iteration bump, final_response, route) happen here,
        in a graph node, because LangGraph discards writes made inside
        conditional-edge functions — mutating state in should_regenerate would
        leave final_response empty and iteration stuck at 0 (endless loop).

        Args:
            state: Current agent state

        Returns:
            Updated state with reflection results and the chosen route
        """
        query = state["query"]
        answer = state["answer"]
        context = state.get("retrieved_context", "")
        chunks = state.get("retrieved_chunks", [])

        # Evaluate answer
        reflection_result = self.reflection_evaluator.evaluate(
            query, answer, context, chunks
        )

        iteration = state.get("iteration", 0)
        recommendation = reflection_result.get("recommendation", "ACCEPT")

        if recommendation == "REJECT" and iteration < self.max_iterations:
            # Rejected with budget left: bump the attempt counter so
            # answer_node sees a regeneration pass on re-entry.
            state["iteration"] = iteration + 1
            reflection_result["route"] = "regenerate"
            print(f"\nāš ļø Answer rejected. Regenerating (iteration {state['iteration']})...\n")
        elif recommendation == "ACCEPT" or iteration >= self.max_iterations:
            # Accepted outright, or iteration budget exhausted.
            state["final_response"] = state["answer"]
            reflection_result["route"] = "accept"
        else:
            # Any other recommendation (e.g. partial relevance): accept as-is.
            state["final_response"] = state["answer"]
            reflection_result["route"] = "end"

        state["reflection"] = reflection_result
        return state

    def should_regenerate(self, state: AgentState) -> str:
        """
        Conditional edge: Route based on the decision recorded by reflect_node.

        Pure read — no state mutation, per LangGraph conditional-edge rules.

        Args:
            state: Current agent state

        Returns:
            "accept", "regenerate", or "end"
        """
        return state["reflection"].get("route", "accept")

    def query(self, question: str) -> Dict[str, Any]:
        """
        Process a query through the agent workflow.

        Args:
            question: User question

        Returns:
            Complete agent response with all state information
        """
        print("\n" + "="*70)
        print(" "*20 + "šŸ¤– RAG Q&A AGENT šŸ¤–")
        print("="*70 + "\n")
        print(f"User Query: {question}")
        print("="*70)

        # Initialize state
        initial_state = AgentState(
            query=question,
            plan="",
            needs_retrieval=True,
            retrieved_context="",
            retrieved_chunks=[],
            answer="",
            reflection={},
            final_response="",
            iteration=0
        )

        # Run the graph
        final_state = self.graph.invoke(initial_state)

        # Print final result
        print("\n" + "="*70)
        print("āœ… FINAL RESPONSE")
        print("="*70 + "\n")
        print(final_state["final_response"])
        print("\n" + "="*70 + "\n")

        return final_state


def create_rag_agent(
    rag_pipeline: RAGPipeline,
    llm_handler: LLMHandler,
    reflection_evaluator: ReflectionEvaluator,
    max_iterations: int = 2
) -> RAGAgent:
    """
    Create and return a RAG agent instance.

    Args:
        rag_pipeline: RAG pipeline for retrieval
        llm_handler: LLM handler for generation
        reflection_evaluator: Reflection evaluator
        max_iterations: Maximum reflection iterations

    Returns:
        RAGAgent instance
    """
    return RAGAgent(rag_pipeline, llm_handler, reflection_evaluator, max_iterations)