File size: 5,830 Bytes
cff1a2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
LangGraph Workflow for Insurance RAG System.
Implements deterministic, compliance-focused retrieval with specialized nodes.
"""

from langgraph.graph import StateGraph, END
from agents.states import AgentState
from agents.nodes import nodes


def build_rag_workflow() -> StateGraph:
    """
    Build (but do not compile) the LangGraph workflow for the insurance RAG system.

    Flow::

        query_rewriter -> query_classifier -> entity_extractor -> retrieval_router
                                          |
                          [conditional routing on state["intent"]]
                                          |
        +------------------------------------------------------------------+
        | list_plans:     listing_agent -> guardrail                       |
        | plan_details:   retriever -> plan_aggregator -> retrieval_agent  |
        |                 -> guardrail                                     |
        | compare_plans:  retriever -> plan_aggregator -> retrieval_agent  |
        |                 -> guardrail  (no dedicated comparison node is   |
        |                 registered, so retrieval_agent handles it)       |
        | recommendation: retriever -> plan_aggregator -> advisory_agent   |
        |                 -> guardrail                                     |
        | general_query:  retriever -> plan_aggregator -> faq_agent        |
        |                 -> guardrail                                     |
        | missing/other:  routed the same way as plan_details              |
        +------------------------------------------------------------------+
        guardrail -> END

    Returns:
        The uncompiled ``StateGraph``; call ``.compile()`` on it to obtain a
        runnable application.
    """

    workflow = StateGraph(AgentState)

    # =========================================================================
    # Add all nodes
    # =========================================================================

    # Pre-processing nodes
    workflow.add_node("query_rewriter", nodes.query_rewriter_node)
    workflow.add_node("query_classifier", nodes.query_classifier_node)
    workflow.add_node("entity_extractor", nodes.entity_extractor_node)
    workflow.add_node("retrieval_router", nodes.retrieval_router_node)

    # Retrieval nodes
    workflow.add_node("retriever", nodes.retriever_node)
    workflow.add_node("plan_aggregator", nodes.plan_aggregator_node)

    # Agent nodes
    workflow.add_node("listing_agent", nodes.listing_agent)
    workflow.add_node("retrieval_agent", nodes.retrieval_agent)
    workflow.add_node("advisory_agent", nodes.advisory_agent)
    workflow.add_node("faq_agent", nodes.faq_agent)

    # Post-processing
    workflow.add_node("guardrail", nodes.guardrail_node)

    # =========================================================================
    # Define edges
    # =========================================================================

    # Entry point
    workflow.set_entry_point("query_rewriter")

    # Linear pre-processing chain
    workflow.add_edge("query_rewriter", "query_classifier")
    workflow.add_edge("query_classifier", "entity_extractor")
    workflow.add_edge("entity_extractor", "retrieval_router")

    # Intents that go through retrieval, mapped to the agent that answers them
    # after aggregation. compare_plans is listed explicitly so the fallback is
    # visible in code: no comparison agent node is registered above.
    agent_for_intent = {
        "plan_details": "retrieval_agent",
        "compare_plans": "retrieval_agent",
        "recommendation": "advisory_agent",
        "general_query": "faq_agent",
    }
    # Used for a missing or unrecognised intent (same handling as plan_details).
    default_agent = "retrieval_agent"

    def route_by_intent(state: AgentState) -> str:
        """Send list_plans straight to the listing agent; everything else retrieves first."""
        intent = state.get("intent", "plan_details")
        return "listing_agent" if intent == "list_plans" else "retriever"

    workflow.add_conditional_edges(
        "retrieval_router",
        route_by_intent,
        {
            "listing_agent": "listing_agent",
            "retriever": "retriever",
        },
    )

    # Listing agent goes directly to guardrail
    workflow.add_edge("listing_agent", "guardrail")

    # Retriever always goes to aggregator
    workflow.add_edge("retriever", "plan_aggregator")

    def route_to_agent(state: AgentState) -> str:
        """Pick the answering agent after aggregation; unknown intents use default_agent."""
        intent = state.get("intent", "plan_details")
        return agent_for_intent.get(intent, default_agent)

    workflow.add_conditional_edges(
        "plan_aggregator",
        route_to_agent,
        {
            "retrieval_agent": "retrieval_agent",
            "advisory_agent": "advisory_agent",
            "faq_agent": "faq_agent",
        },
    )

    # All agents end at guardrail
    workflow.add_edge("retrieval_agent", "guardrail")
    workflow.add_edge("advisory_agent", "guardrail")
    workflow.add_edge("faq_agent", "guardrail")

    # Guardrail ends the workflow
    workflow.add_edge("guardrail", END)

    return workflow


# Build and compile the workflow once at import time; `app` is the runnable graph.
workflow = build_rag_workflow()
app = workflow.compile()


if __name__ == "__main__":
    # Manual smoke test: run a few sample queries end to end and print results.
    import traceback

    print("Graph compiled successfully!")

    # Sample queries that appear intended to exercise the list / details /
    # compare / recommendation paths (actual routing depends on the classifier).
    test_queries = [
        "List all term plans from Tata AIA",
        "Explain the TATA AIA Smart Value Income plan",
        "Compare Tata AIA vs Edelweiss term plans",
        "Suggest a plan for 30-year-old non-smoker with 1Cr cover"
    ]

    # Maximum number of answer characters shown per query.
    answer_preview_chars = 500

    for query in test_queries:
        print(f"\n{'=' * 60}")
        print(f"Query: {query}")
        print("=" * 60)

        # Initial state with an empty value for each key.
        # NOTE(review): presumably mirrors AgentState's fields; verify against agents.states.
        initial_state = {
            "input": query,
            "chat_history": [],
            "intent": "",
            "extracted_entities": {},
            "metadata_filters": {},
            "retrieval_strategy": "",
            "context": [],
            "retrieved_chunks": {},
            "reasoning_output": "",
            "answer": "",
            "next_step": ""
        }

        try:
            result = app.invoke(initial_state)
        except Exception as e:
            # Best-effort smoke test: report the failure with its traceback and
            # move on to the next query instead of aborting the whole run.
            print(f"Error: {e}")
            traceback.print_exc()
        else:
            # `or ""` guards against an explicit None answer; only mark truncation
            # with "..." when the answer was actually cut.
            answer = result.get("answer") or ""
            if len(answer) > answer_preview_chars:
                answer = answer[:answer_preview_chars] + "..."
            print(f"Intent: {result.get('intent')}")
            print(f"Answer: {answer}")