Spaces:
Running
Running
File size: 5,830 Bytes
cff1a2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
"""
LangGraph Workflow for Insurance RAG System.
Implements deterministic, compliance-focused retrieval with specialized nodes.
"""
from langgraph.graph import StateGraph, END
from agents.states import AgentState
from agents.nodes import nodes
def build_rag_workflow() -> StateGraph:
    """
    Assemble the uncompiled LangGraph state graph for the insurance RAG system.

    Every request walks the same pre-processing chain::

        query_rewriter -> query_classifier -> entity_extractor -> retrieval_router

    After ``retrieval_router`` the path depends on ``state["intent"]``::

        list_plans     : listing_agent -> guardrail -> END
        plan_details   : retriever -> plan_aggregator -> retrieval_agent -> guardrail -> END
        recommendation : retriever -> plan_aggregator -> advisory_agent  -> guardrail -> END
        general_query  : retriever -> plan_aggregator -> faq_agent       -> guardrail -> END
        anything else  : retriever -> plan_aggregator -> retrieval_agent -> guardrail -> END

    NOTE(review): no dedicated comparison node is registered here, so a
    ``compare_plans`` intent takes the "anything else" path and is answered by
    ``retrieval_agent``. Confirm whether that is intended.

    Returns:
        The ``StateGraph`` with all nodes and edges added. It is not compiled;
        the caller is expected to call ``.compile()`` on it.
    """
    graph = StateGraph(AgentState)

    # ---------------------------------------------------------------------
    # Nodes (registered in pipeline order)
    # ---------------------------------------------------------------------
    node_table = (
        # Pre-processing
        ("query_rewriter", nodes.query_rewriter_node),
        ("query_classifier", nodes.query_classifier_node),
        ("entity_extractor", nodes.entity_extractor_node),
        ("retrieval_router", nodes.retrieval_router_node),
        # Retrieval
        ("retriever", nodes.retriever_node),
        ("plan_aggregator", nodes.plan_aggregator_node),
        # Answering agents
        ("listing_agent", nodes.listing_agent),
        ("retrieval_agent", nodes.retrieval_agent),
        ("advisory_agent", nodes.advisory_agent),
        ("faq_agent", nodes.faq_agent),
        # Post-processing
        ("guardrail", nodes.guardrail_node),
    )
    for node_name, handler in node_table:
        graph.add_node(node_name, handler)

    # ---------------------------------------------------------------------
    # Edges
    # ---------------------------------------------------------------------
    graph.set_entry_point("query_rewriter")

    # Fixed pre-processing chain.
    preprocessing_chain = (
        ("query_rewriter", "query_classifier"),
        ("query_classifier", "entity_extractor"),
        ("entity_extractor", "retrieval_router"),
    )
    for source, target in preprocessing_chain:
        graph.add_edge(source, target)

    def route_by_intent(state: AgentState) -> str:
        """Pick the branch after the router: listing shortcut or retrieval."""
        # A missing "intent" key is treated as "plan_details", i.e. retrieval.
        if state.get("intent", "plan_details") == "list_plans":
            return "listing_agent"
        return "retriever"

    graph.add_conditional_edges(
        "retrieval_router",
        route_by_intent,
        {name: name for name in ("listing_agent", "retriever")},
    )

    # The listing agent skips retrieval and goes straight to the guardrail.
    graph.add_edge("listing_agent", "guardrail")

    # Retrieved chunks are always aggregated before an agent sees them.
    graph.add_edge("retriever", "plan_aggregator")

    # Intent -> answering agent; intents not listed use "retrieval_agent".
    agent_for_intent = {
        "plan_details": "retrieval_agent",
        "recommendation": "advisory_agent",
        "general_query": "faq_agent",
    }

    def route_to_agent(state: AgentState) -> str:
        """Pick the answering agent that follows the aggregator."""
        return agent_for_intent.get(
            state.get("intent", "plan_details"), "retrieval_agent"
        )

    answering_agents = ("retrieval_agent", "advisory_agent", "faq_agent")
    graph.add_conditional_edges(
        "plan_aggregator",
        route_to_agent,
        {name: name for name in answering_agents},
    )

    # Every answering agent hands its output to the guardrail...
    for agent_name in answering_agents:
        graph.add_edge(agent_name, "guardrail")

    # ...and the guardrail is the last step of the workflow.
    graph.add_edge("guardrail", END)

    return graph
# Build and compile the workflow when this module is imported.
# `workflow` holds the uncompiled StateGraph returned by build_rag_workflow();
# `app` is the result of workflow.compile() (called with no arguments) and is
# the object the __main__ smoke test below calls .invoke() on.
workflow = build_rag_workflow()
app = workflow.compile()
if __name__ == "__main__":
    # Manual smoke test: push a few sample queries through the compiled graph
    # and print the detected intent plus the first 500 characters of the answer.
    print("Graph compiled successfully!")

    banner = "=" * 60
    sample_queries = (
        "List all term plans from Tata AIA",
        "Explain the TATA AIA Smart Value Income plan",
        "Compare Tata AIA vs Edelweiss term plans",
        "Suggest a plan for 30-year-old non-smoker with 1Cr cover",
    )

    for query in sample_queries:
        print(f"\n{banner}")
        print(f"Query: {query}")
        print(banner)

        # Fresh state per query so no containers are shared between runs.
        initial_state = dict(
            input=query,
            chat_history=[],
            intent="",
            extracted_entities={},
            metadata_filters={},
            retrieval_strategy="",
            context=[],
            retrieved_chunks={},
            reasoning_output="",
            answer="",
            next_step="",
        )

        try:
            final_state = app.invoke(initial_state)
            detected_intent = final_state.get("intent")
            print(f"Intent: {detected_intent}")
            answer_preview = final_state.get("answer", "")[:500]
            print(f"Answer: {answer_preview}...")
        except Exception as e:
            # Best-effort script: report the failure and move on to the next query.
            print(f"Error: {e}")
|