"""
LangGraph workflow implementation for the RAG Q&A Agent.
Defines the agent graph with plan, retrieve, answer, and reflect nodes.
"""
| from typing import TypedDict, Annotated, Dict, Any, List | |
| from langgraph.graph import StateGraph, END | |
| from rag_pipeline import RAGPipeline | |
| from llm_utils import LLMHandler | |
| from reflection import ReflectionEvaluator | |
| import operator | |
# Define the agent state
class AgentState(TypedDict):
    """State passed between nodes in the agent workflow."""
    # Original user question being processed.
    query: str
    # Raw planner-LLM response (reasoning + retrieval decision text).
    plan: str
    # Whether the plan node decided to run the retrieve node.
    needs_retrieval: bool
    # Concatenated context text returned by the RAG pipeline.
    retrieved_context: str
    # Raw chunk records from retrieval; each has at least a 'content' key.
    retrieved_chunks: List[Dict[str, Any]]
    # Latest answer produced by the answer node.
    answer: str
    # Evaluation result from the reflection evaluator
    # (expected keys include 'recommendation' and 'reasoning' — confirm in reflection.py).
    reflection: Dict[str, Any]
    # Answer accepted as the final output of the workflow.
    final_response: str
    # Number of regeneration attempts performed so far.
    iteration: int
class RAGAgent:
    """LangGraph-based RAG Q&A Agent with reflection.

    Workflow: plan -> (retrieve ->) answer -> reflect, where reflect may
    loop back to answer for up to ``max_iterations`` regeneration attempts.

    NOTE: all state mutations happen inside graph *nodes*. Conditional-edge
    functions (``should_retrieve``, ``should_regenerate``) are pure readers,
    because LangGraph does not persist state changes made in edge callbacks.
    """

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        llm_handler: LLMHandler,
        reflection_evaluator: ReflectionEvaluator,
        max_iterations: int = 2
    ):
        """
        Initialize the RAG agent.

        Args:
            rag_pipeline: RAG pipeline for retrieval
            llm_handler: LLM handler for generation
            reflection_evaluator: Reflection evaluator
            max_iterations: Maximum reflection iterations
        """
        self.rag_pipeline = rag_pipeline
        self.llm_handler = llm_handler
        self.reflection_evaluator = reflection_evaluator
        self.max_iterations = max_iterations
        # Build the graph once; the compiled graph is reusable across queries.
        self.graph = self._build_graph()
        print("✓ RAG Agent workflow initialized")

    def _build_graph(self):
        """Build and compile the LangGraph workflow.

        Returns:
            The compiled graph, ready for ``.invoke``.
        """
        workflow = StateGraph(AgentState)

        # Nodes
        workflow.add_node("plan", self.plan_node)
        workflow.add_node("retrieve", self.retrieve_node)
        workflow.add_node("answer", self.answer_node)
        workflow.add_node("reflect", self.reflect_node)

        # Entry point
        workflow.set_entry_point("plan")

        # Plan -> Retrieve or Answer, depending on the planner's decision.
        workflow.add_conditional_edges(
            "plan",
            self.should_retrieve,
            {
                True: "retrieve",
                False: "answer"
            }
        )

        # Retrieve -> Answer
        workflow.add_edge("retrieve", "answer")

        # Answer -> Reflect
        workflow.add_edge("answer", "reflect")

        # Reflect -> End or Answer (for regeneration). The routing decision
        # is computed inside reflect_node and stored in the reflection dict.
        workflow.add_conditional_edges(
            "reflect",
            self.should_regenerate,
            {
                "accept": END,
                "regenerate": "answer",
                "end": END
            }
        )

        return workflow.compile()

    @staticmethod
    def _parse_needs_retrieval(plan_response: str) -> bool:
        """Parse the NEEDS_RETRIEVAL flag out of the planner's response.

        Args:
            plan_response: Raw LLM planning response.

        Returns:
            True when the planner answered YES, or when the marker is
            missing/malformed (retrieving is the safer default).
        """
        upper = plan_response.upper()
        marker = "NEEDS_RETRIEVAL:"
        if marker not in upper:
            return True
        # Only inspect the remainder of the marker's own line; split("\n")[0]
        # is safe even when nothing follows the marker.
        first_line = upper.split(marker, 1)[1].split("\n")[0]
        return "YES" in first_line

    def plan_node(self, state: AgentState) -> AgentState:
        """
        Planning node: Analyze query and decide if retrieval is needed.

        Args:
            state: Current agent state

        Returns:
            Updated state with plan, needs_retrieval, and iteration reset
        """
        print("\n" + "="*60)
        print("📋 NODE: PLAN")
        print("="*60 + "\n")
        query = state["query"]
        print(f"Query: {query}\n")

        # Use LLM to analyze query and create a plan
        planning_prompt = f"""Analyze the following user query and determine if it requires retrieving information from a knowledge base.
User Query: "{query}"
Consider:
1. Is this a factual question that would benefit from specific documentation or knowledge?
2. Is this a general question that can be answered without specific context?
3. Does this query ask about specific concepts, technologies, or topics?
Respond in the following format:
NEEDS_RETRIEVAL: [YES/NO]
REASONING: [Brief explanation]
PLAN: [How you will approach answering this query]"""
        system_message = "You are a query planning agent. Analyze queries and determine the best approach to answer them."

        plan_response = self.llm_handler.generate(
            planning_prompt,
            system_message
        )

        needs_retrieval = self._parse_needs_retrieval(plan_response)

        print(f"Plan Response:\n{plan_response}\n")
        print(f"Needs Retrieval: {needs_retrieval}")

        state["plan"] = plan_response
        state["needs_retrieval"] = needs_retrieval
        state["iteration"] = 0
        print("\n" + "="*60 + "\n")
        return state

    def should_retrieve(self, state: AgentState) -> bool:
        """Conditional edge: Determine if retrieval is needed (pure read)."""
        return state["needs_retrieval"]

    def retrieve_node(self, state: AgentState) -> AgentState:
        """
        Retrieval node: Retrieve relevant context from vector store.

        Args:
            state: Current agent state

        Returns:
            Updated state with retrieved context and chunks
        """
        print("\n" + "="*60)
        print("🔍 NODE: RETRIEVE")
        print("="*60 + "\n")
        query = state["query"]

        context, chunks = self.rag_pipeline.retrieve_context(query, top_k=3)
        print(f"Retrieved {len(chunks)} relevant chunks\n")

        # Display a short preview of each retrieved chunk.
        for i, chunk in enumerate(chunks):
            preview = chunk['content'][:150] + "..." if len(chunk['content']) > 150 else chunk['content']
            print(f"Chunk {i+1} Preview: {preview}\n")

        state["retrieved_context"] = context
        state["retrieved_chunks"] = chunks
        print("="*60 + "\n")
        return state

    def answer_node(self, state: AgentState) -> AgentState:
        """
        Answer generation node: Generate answer using LLM.

        Uses retrieved context when available; on a regeneration pass the
        reflection feedback is folded into the prompt.

        Args:
            state: Current agent state

        Returns:
            Updated state with generated answer
        """
        print("\n" + "="*60)
        print("💬 NODE: ANSWER")
        print("="*60 + "\n")
        query = state["query"]
        iteration = state.get("iteration", 0)
        if iteration > 0:
            print(f"[Regeneration attempt {iteration}]\n")

        if state.get("retrieved_context"):
            context = state["retrieved_context"]
            if iteration > 0 and state.get("reflection"):
                # Regeneration: include the evaluator's critique. .get guards
                # against a reflection result missing the 'reasoning' key.
                feedback = state["reflection"].get("reasoning", "")
                answer = self._generate_answer_with_feedback(query, context, feedback)
            else:
                answer = self.llm_handler.generate_with_context(
                    query,
                    context,
                    system_message="You are a helpful AI assistant. Answer questions accurately based on the provided context."
                )
        else:
            # No retrieved context: answer directly.
            answer = self.llm_handler.generate(
                query,
                system_message="You are a helpful AI assistant. Answer questions concisely and accurately."
            )

        print(f"Generated Answer:\n{answer}\n")
        state["answer"] = answer
        print("="*60 + "\n")
        return state

    def _generate_answer_with_feedback(
        self,
        query: str,
        context: str,
        feedback: str
    ) -> str:
        """
        Generate answer incorporating feedback from reflection.

        Args:
            query: User query
            context: Retrieved context
            feedback: Feedback from reflection

        Returns:
            Regenerated answer
        """
        prompt = f"""The previous answer was not satisfactory. Here's the feedback:
{feedback}
Now, please generate a better answer to the following question using the context provided.
Context:
{context}
Question: {query}
Provide a comprehensive, accurate, and relevant answer that addresses the feedback."""
        system_message = "You are a helpful AI assistant. Learn from feedback and provide improved answers."
        return self.llm_handler.generate(prompt, system_message)

    def reflect_node(self, state: AgentState) -> AgentState:
        """
        Reflection node: Evaluate answer quality and decide the next step.

        The routing decision ('accept' / 'regenerate' / 'end') is computed
        HERE and stored in the reflection dict, because state mutations in
        conditional-edge callbacks are not persisted by LangGraph. The
        iteration counter is also bumped here when regenerating.

        Args:
            state: Current agent state

        Returns:
            Updated state with reflection results (and final_response when done)
        """
        print("\n" + "="*60)
        print("🤔 NODE: REFLECT")
        print("="*60 + "\n")
        query = state["query"]
        answer = state["answer"]
        context = state.get("retrieved_context", "")
        chunks = state.get("retrieved_chunks", [])

        reflection_result = self.reflection_evaluator.evaluate(
            query,
            answer,
            context,
            chunks
        )

        iteration = state.get("iteration", 0)
        recommendation = reflection_result.get("recommendation", "ACCEPT")

        if recommendation == "ACCEPT":
            decision = "accept"
        elif recommendation == "REJECT" and iteration < self.max_iterations:
            decision = "regenerate"
        else:
            # Rejected at the iteration cap, or partial relevance: stop here.
            decision = "end"

        reflection_result["decision"] = decision
        state["reflection"] = reflection_result

        if decision == "regenerate":
            state["iteration"] = iteration + 1
            print(f"\n⚠️ Answer rejected. Regenerating (iteration {state['iteration']})...\n")
        else:
            state["final_response"] = answer

        return state

    def should_regenerate(self, state: AgentState) -> str:
        """
        Conditional edge: Route based on the decision made in reflect_node.

        Pure read — no state mutation (edge callbacks cannot persist changes).

        Args:
            state: Current agent state

        Returns:
            'accept', 'regenerate', or 'end'
        """
        return state["reflection"].get("decision", "end")

    def query(self, question: str) -> Dict[str, Any]:
        """
        Process a query through the agent workflow.

        Args:
            question: User question

        Returns:
            Complete agent response with all state information
        """
        print("\n" + "="*70)
        print(" "*20 + "🤖 RAG Q&A AGENT 🤖")
        print("="*70 + "\n")
        print(f"User Query: {question}")
        print("="*70)

        # Initialize state
        initial_state = AgentState(
            query=question,
            plan="",
            needs_retrieval=True,
            retrieved_context="",
            retrieved_chunks=[],
            answer="",
            reflection={},
            final_response="",
            iteration=0
        )

        # Run the graph
        final_state = self.graph.invoke(initial_state)

        # Print final result; fall back to the last answer if the graph
        # somehow ended without setting final_response.
        print("\n" + "="*70)
        print("✅ FINAL RESPONSE")
        print("="*70 + "\n")
        print(final_state.get("final_response") or final_state.get("answer", ""))
        print("\n" + "="*70 + "\n")

        return final_state
def create_rag_agent(
    rag_pipeline: RAGPipeline,
    llm_handler: LLMHandler,
    reflection_evaluator: ReflectionEvaluator,
    max_iterations: int = 2
) -> RAGAgent:
    """Factory helper: build a ready-to-use RAG agent.

    Simply forwards its arguments to ``RAGAgent.__init__``.

    Args:
        rag_pipeline: RAG pipeline used for context retrieval.
        llm_handler: LLM handler used for generation.
        reflection_evaluator: Evaluator that scores generated answers.
        max_iterations: Cap on reflection/regeneration passes.

    Returns:
        A constructed ``RAGAgent`` instance.
    """
    agent = RAGAgent(
        rag_pipeline=rag_pipeline,
        llm_handler=llm_handler,
        reflection_evaluator=reflection_evaluator,
        max_iterations=max_iterations,
    )
    return agent