# Source: langgraph-rag-agent / src/agent_workflow.py (Harsh-1132, "hf", a77376b)
"""
LangGraph workflow implementation for the RAG Q&A Agent.
Defines the agent graph with plan, retrieve, answer, and reflect nodes.
"""
import operator
import re
from typing import TypedDict, Annotated, Dict, Any, List

from langgraph.graph import StateGraph, END

from rag_pipeline import RAGPipeline
from llm_utils import LLMHandler
from reflection import ReflectionEvaluator
# Define the agent state
class AgentState(TypedDict):
    """
    Dictionary handed from node to node in the agent graph.
    Every node reads the keys it needs, writes its results back into the
    same dictionary and returns it.
    """
    # The user's question.
    query: str
    # Written by the plan node: raw planner output and its retrieval decision.
    plan: str
    needs_retrieval: bool
    # Written by the retrieve node: context string and the chunks it came from.
    retrieved_context: str
    retrieved_chunks: List[Dict[str, Any]]
    # Written by the answer node.
    answer: str
    # Written by the reflect node; read for "recommendation" and "reasoning".
    reflection: Dict[str, Any]
    # Answer printed by RAGAgent.query at the end of a run.
    final_response: str
    # Regeneration counter for the reflect -> answer loop.
    iteration: int
class RAGAgent:
    """
    LangGraph-based RAG Q&A agent with reflection.

    Workflow: plan -> (retrieve ->) answer -> reflect. After reflect the run
    either ends or loops back to answer for a regeneration, at most
    ``max_iterations`` times.

    LangGraph persists only the state a *node* returns. Changes a routing
    function (``should_retrieve`` / ``should_regenerate``) makes to its
    ``state`` argument are discarded. All bookkeeping (``iteration``,
    ``final_response``) is therefore done inside nodes, and the routers only
    read the state.
    """

    # Finds the planner's YES/NO decision, tolerating markdown
    # ("**NEEDS_RETRIEVAL:** yes"), mixed case and a value on the next line.
    _NEEDS_RETRIEVAL_PATTERN = re.compile(r"NEEDS_RETRIEVAL\W*(YES|NO)\b", re.IGNORECASE)

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        llm_handler: LLMHandler,
        reflection_evaluator: ReflectionEvaluator,
        max_iterations: int = 2,
        top_k: int = 3,
    ):
        """
        Initialize the RAG agent and compile its LangGraph workflow.
        Args:
            rag_pipeline: RAG pipeline used by the retrieve node
            llm_handler: LLM handler used by the plan and answer nodes
            reflection_evaluator: Evaluator used by the reflect node
            max_iterations: Maximum number of answer regenerations after the
                first answer; once reached, the latest answer is accepted
            top_k: Number of chunks requested from the RAG pipeline
        """
        self.rag_pipeline = rag_pipeline
        self.llm_handler = llm_handler
        self.reflection_evaluator = reflection_evaluator
        self.max_iterations = max_iterations
        self.top_k = top_k
        # Build the graph
        self.graph = self._build_graph()
        print("✓ RAG Agent workflow initialized")

    def _build_graph(self):
        """Build and compile the LangGraph workflow."""
        workflow = StateGraph(AgentState)
        workflow.add_node("plan", self.plan_node)
        workflow.add_node("retrieve", self.retrieve_node)
        workflow.add_node("answer", self.answer_node)
        workflow.add_node("reflect", self.reflect_node)
        workflow.set_entry_point("plan")
        # Plan -> Retrieve or Answer
        workflow.add_conditional_edges(
            "plan",
            self.should_retrieve,
            {
                True: "retrieve",
                False: "answer"
            }
        )
        # Retrieve -> Answer -> Reflect
        workflow.add_edge("retrieve", "answer")
        workflow.add_edge("answer", "reflect")
        # Reflect -> End or Answer (for regeneration)
        workflow.add_conditional_edges(
            "reflect",
            self.should_regenerate,
            {
                "accept": END,
                "regenerate": "answer",
                "end": END
            }
        )
        return workflow.compile()

    @staticmethod
    def _print_node_header(title: str) -> None:
        """Print the banner that opens each node's console output."""
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60 + "\n")

    @classmethod
    def _parse_needs_retrieval(cls, plan_response: str) -> bool:
        """
        Extract the retrieval decision from the planner's response.
        Args:
            plan_response: Raw LLM output requested in the
                "NEEDS_RETRIEVAL: [YES/NO]" format
        Returns:
            True for YES, False for NO. Also True when no recognizable
            decision is present, so an unparseable plan falls back to retrieval.
        """
        match = cls._NEEDS_RETRIEVAL_PATTERN.search(plan_response or "")
        if match is None:
            return True
        return match.group(1).upper() == "YES"

    def plan_node(self, state: AgentState) -> AgentState:
        """
        Planning node: analyze the query and decide if retrieval is needed.
        Args:
            state: Current agent state
        Returns:
            Updated state with ``plan`` and ``needs_retrieval``; the
            regeneration bookkeeping (``iteration``, ``reflection``) is reset
        """
        self._print_node_header("📋 NODE: PLAN")
        query = state["query"]
        print(f"Query: {query}\n")
        # Use LLM to analyze query and create a plan
        planning_prompt = f"""Analyze the following user query and determine if it requires retrieving information from a knowledge base.
User Query: "{query}"
Consider:
1. Is this a factual question that would benefit from specific documentation or knowledge?
2. Is this a general question that can be answered without specific context?
3. Does this query ask about specific concepts, technologies, or topics?
Respond in the following format:
NEEDS_RETRIEVAL: [YES/NO]
REASONING: [Brief explanation]
PLAN: [How you will approach answering this query]"""
        system_message = "You are a query planning agent. Analyze queries and determine the best approach to answer them."
        plan_response = self.llm_handler.generate(planning_prompt, system_message)
        needs_retrieval = self._parse_needs_retrieval(plan_response)
        print(f"Plan Response:\n{plan_response}\n")
        print(f"Needs Retrieval: {needs_retrieval}")
        state["plan"] = plan_response
        state["needs_retrieval"] = needs_retrieval
        # Fresh run: answer_node treats a non-empty reflection as "regenerating",
        # so clear any stale one along with the counter.
        state["iteration"] = 0
        state["reflection"] = {}
        print("\n" + "=" * 60 + "\n")
        return state

    def should_retrieve(self, state: AgentState) -> bool:
        """Routing function after plan: True -> retrieve, False -> answer."""
        return bool(state.get("needs_retrieval", True))

    def retrieve_node(self, state: AgentState) -> AgentState:
        """
        Retrieval node: fetch relevant context from the RAG pipeline.
        Args:
            state: Current agent state
        Returns:
            Updated state with ``retrieved_context`` and ``retrieved_chunks``
        """
        self._print_node_header("🔍 NODE: RETRIEVE")
        query = state["query"]
        context, chunks = self.rag_pipeline.retrieve_context(query, top_k=self.top_k)
        print(f"Retrieved {len(chunks)} relevant chunks\n")
        # Display retrieved content preview (first 150 characters)
        for number, chunk in enumerate(chunks, start=1):
            content = chunk['content']
            preview = content[:150] + "..." if len(content) > 150 else content
            print(f"Chunk {number} Preview: {preview}\n")
        state["retrieved_context"] = context
        state["retrieved_chunks"] = chunks
        print("=" * 60 + "\n")
        return state

    def answer_node(self, state: AgentState) -> AgentState:
        """
        Answer generation node: generate (or regenerate) the answer with the LLM.
        A non-empty ``reflection`` in the incoming state means the graph came
        back here from ``reflect``, so this call is a regeneration. In that
        case the ``iteration`` counter is incremented here (inside a node, so
        LangGraph persists it), and the reflection's ``reasoning`` is fed back
        to the LLM as feedback.
        Args:
            state: Current agent state
        Returns:
            Updated state with ``answer`` and ``iteration``
        """
        self._print_node_header("💬 NODE: ANSWER")
        query = state["query"]
        context = state.get("retrieved_context")
        reflection = state.get("reflection") or {}
        iteration = state.get("iteration", 0)
        if reflection:
            iteration += 1
            print(f"[Regeneration attempt {iteration}]\n")
            answer = self._generate_answer_with_feedback(
                query, context or "", reflection.get("reasoning", "")
            )
        elif context:
            answer = self.llm_handler.generate_with_context(
                query,
                context,
                system_message="You are a helpful AI assistant. Answer questions accurately based on the provided context."
            )
        else:
            answer = self.llm_handler.generate(
                query,
                system_message="You are a helpful AI assistant. Answer questions concisely and accurately."
            )
        print(f"Generated Answer:\n{answer}\n")
        state["answer"] = answer
        state["iteration"] = iteration
        print("=" * 60 + "\n")
        return state

    def _generate_answer_with_feedback(
        self,
        query: str,
        context: str,
        feedback: str
    ) -> str:
        """
        Generate an answer that incorporates feedback from reflection.
        Args:
            query: User query
            context: Retrieved context; may be empty when retrieval was skipped,
                in which case the prompt omits the context section
            feedback: Feedback from reflection
        Returns:
            Regenerated answer
        """
        if context:
            prompt = f"""The previous answer was not satisfactory. Here's the feedback:
{feedback}
Now, please generate a better answer to the following question using the context provided.
Context:
{context}
Question: {query}
Provide a comprehensive, accurate, and relevant answer that addresses the feedback."""
        else:
            prompt = f"""The previous answer was not satisfactory. Here's the feedback:
{feedback}
Now, please generate a better answer to the following question.
Question: {query}
Provide a comprehensive, accurate, and relevant answer that addresses the feedback."""
        system_message = "You are a helpful AI assistant. Learn from feedback and provide improved answers."
        return self.llm_handler.generate(prompt, system_message)

    def reflect_node(self, state: AgentState) -> AgentState:
        """
        Reflection node: evaluate the answer and publish it as the final response.
        Every run ends right after this node, so the latest answer is copied
        into ``final_response`` here. A regeneration simply overwrites it on
        the next pass. This must happen in a node: the routing function's
        state changes are not persisted by LangGraph.
        Args:
            state: Current agent state
        Returns:
            Updated state with ``reflection`` and ``final_response``
        """
        query = state["query"]
        answer = state["answer"]
        context = state.get("retrieved_context", "")
        chunks = state.get("retrieved_chunks", [])
        state["reflection"] = self.reflection_evaluator.evaluate(
            query,
            answer,
            context,
            chunks
        )
        state["final_response"] = answer
        return state

    def should_regenerate(self, state: AgentState) -> str:
        """
        Routing function after reflect; only reads the state.
        Args:
            state: Current agent state
        Returns:
            "accept" when the evaluator recommends ACCEPT (also the default) or
            ``max_iterations`` regenerations have already happened;
            "regenerate" on REJECT; "end" for any other recommendation
        """
        reflection = state.get("reflection") or {}
        iteration = state.get("iteration", 0)
        recommendation = reflection.get("recommendation", "ACCEPT")
        if recommendation == "ACCEPT" or iteration >= self.max_iterations:
            return "accept"
        if recommendation == "REJECT":
            # answer_node performs (and persists) the increment; this is display only.
            print(f"\n⚠️ Answer rejected. Regenerating (iteration {iteration + 1})...\n")
            return "regenerate"
        # Otherwise, accept with partial relevance
        return "end"

    def _recursion_limit(self) -> int:
        """
        Return a LangGraph recursion limit large enough for ``max_iterations``.
        A run takes at most plan + retrieve + (answer + reflect) * (max_iterations + 1)
        steps. Some headroom is added, and the result never drops below 25.
        """
        return max(25, 2 * self.max_iterations + 10)

    def query(self, question: str) -> Dict[str, Any]:
        """
        Process a query through the agent workflow.
        Args:
            question: User question
        Returns:
            Complete final agent state (``final_response`` holds the answer)
        """
        print("\n" + "=" * 70)
        print(" " * 20 + "🤖 RAG Q&A AGENT 🤖")
        print("=" * 70 + "\n")
        print(f"User Query: {question}")
        print("=" * 70)
        initial_state = AgentState(
            query=question,
            plan="",
            needs_retrieval=True,
            retrieved_context="",
            retrieved_chunks=[],
            answer="",
            reflection={},
            final_response="",
            iteration=0
        )
        final_state = self.graph.invoke(
            initial_state,
            config={"recursion_limit": self._recursion_limit()},
        )
        print("\n" + "=" * 70)
        print("✅ FINAL RESPONSE")
        print("=" * 70 + "\n")
        print(final_state["final_response"])
        print("\n" + "=" * 70 + "\n")
        return final_state
def create_rag_agent(
    rag_pipeline: RAGPipeline,
    llm_handler: LLMHandler,
    reflection_evaluator: ReflectionEvaluator,
    max_iterations: int = 2
) -> RAGAgent:
    """
    Factory helper that builds a ready-to-use RAGAgent.
    Args:
        rag_pipeline: Pipeline the agent retrieves context from
        llm_handler: Handler the agent uses for planning and answering
        reflection_evaluator: Evaluator that scores each generated answer
        max_iterations: Forwarded unchanged to RAGAgent
    Returns:
        A newly constructed RAGAgent (its graph is compiled during construction)
    """
    agent = RAGAgent(
        rag_pipeline=rag_pipeline,
        llm_handler=llm_handler,
        reflection_evaluator=reflection_evaluator,
        max_iterations=max_iterations,
    )
    return agent