Sulaiman8 commited on
Commit
aab5cdb
·
verified ·
1 Parent(s): 9026943

Update langgraph_pipeline.py

Browse files
Files changed (1) hide show
  1. langgraph_pipeline.py +197 -200
langgraph_pipeline.py CHANGED
@@ -1,200 +1,197 @@
1
- from langgraph.graph import StateGraph, END
2
- from typing import Literal
3
- from data import debug_print
4
- from langchain_core.messages import AIMessage
5
- from langgraph.prebuilt import ToolNode,tools_condition
6
- from nodes.agent import agent_node,TOOLS
7
- from nodes.intent import oos_handler_node,general_info_handler_node,intent_classifier_node,CreditCardState
8
- from nodes.format import format_output_node
9
- from nodes.compare import compare_node_fn
10
- from nodes.chat import router_node,tool_node,expert_agent_node
11
- from recommender.graph_retrieval import neo4j_error_handler_node,neo4j_retrieval_node
12
- from recommender.vectordb import query_refiner_node
13
- from recommender.vectordb_retrieval import ranked_card_retrieval_node
14
-
15
-
16
- # Main Graph flow
17
- graph = StateGraph(CreditCardState)
18
- graph.add_node("intent_classifier", intent_classifier_node)
19
- graph.add_node("general_info_handler", general_info_handler_node)
20
- graph.add_node("oos_handler", oos_handler_node)
21
- graph.add_node("query_refiner", query_refiner_node)
22
- graph.add_node("neo4j_retriever", neo4j_retrieval_node)
23
- graph.add_node("neo4j_error_handler", neo4j_error_handler_node)
24
- graph.add_node("ranked_card_retrieval", ranked_card_retrieval_node)
25
- graph.add_node("agent", agent_node)
26
- graph.add_node("format_output", format_output_node)
27
-
28
- graph.set_entry_point("intent_classifier")
29
-
30
- def route_after_intent_classification(state: CreditCardState):
31
- intent = state["intent"]
32
- debug_print("ROUTE", f"Intent classification routing with intent: '{intent}'")
33
-
34
- if intent == "credit-card-recommendation":
35
- return "query_refiner"
36
- elif intent == "general-credit-related":
37
- return "general_info_handler"
38
- else:
39
- return "oos_handler"
40
-
41
- def route_after_format_output(state: CreditCardState):
42
- if state.get("trigger_compare", False):
43
- return "compare_node"
44
- elif state.get("trigger_chat", False):
45
- return "chat_node"
46
- else:
47
- return END
48
-
49
- graph.add_conditional_edges(
50
- "intent_classifier",
51
- route_after_intent_classification,
52
- {
53
- "query_refiner": "query_refiner",
54
- "general_info_handler": "general_info_handler",
55
- "oos_handler": "oos_handler",
56
- },
57
- )
58
-
59
- graph.add_edge("general_info_handler", END)
60
- graph.add_edge("oos_handler", END)
61
- graph.add_edge("query_refiner", "neo4j_retriever")
62
-
63
- def route_after_neo4j_retriever(state: CreditCardState):
64
- debug_print("ROUTE", f"neo4j_error: {state.get('neo4j_error')}")
65
- if state.get("neo4j_error", False):
66
- return "neo4j_error_handler"
67
- else:
68
- return "ranked_card_retrieval"
69
-
70
-
71
- graph.add_conditional_edges(
72
- "neo4j_retriever",
73
- route_after_neo4j_retriever,
74
- {
75
- "neo4j_error_handler": "neo4j_error_handler",
76
- "ranked_card_retrieval": "ranked_card_retrieval",
77
- },
78
- )
79
-
80
- graph.add_edge("neo4j_error_handler", END)
81
- graph.add_edge("ranked_card_retrieval", "agent")
82
-
83
- graph.add_edge("agent", "format_output")
84
- graph.add_edge("format_output",END)
85
-
86
- app = graph.compile()
87
-
88
- # invoking function
89
- async def run_langgraph_pipeline(
90
- query: str,
91
- preferences: str,
92
- query_intent: bool,
93
- include_cobranded: bool,
94
- use_eligibility: bool = False,
95
- age=None,
96
- income=None,
97
- cibil=None,
98
- min_joining_fee=None,
99
- max_joining_fee=None,
100
- min_annual_fee=None,
101
- max_annual_fee=None
102
- ):
103
- debug_print("PIPELINE", f"Starting pipeline with query: '{query}'")
104
- debug_print("PIPELINE", f"Preferences: '{preferences}'")
105
- debug_print("PIPELINE", f"Query intent: {query_intent}, Include cobranded: {include_cobranded}")
106
- debug_print("PIPELINE", f"Eligibility: {use_eligibility}, Age: {age}, Income: {income}, CIBIL: {cibil}")
107
- debug_print("PIPELINE", f"Join fee: {min_joining_fee}-{max_joining_fee}, Annual fee: {min_annual_fee}-{max_annual_fee}")
108
-
109
- inputs = {
110
- "query": query,
111
- "preferences": preferences,
112
- "query_intent": query_intent,
113
- "include_cobranded": include_cobranded,
114
- "use_eligibility": use_eligibility,
115
- "age": age,
116
- "income": income,
117
- "cibil": cibil,
118
- "min_joining_fee": min_joining_fee,
119
- "max_joining_fee": max_joining_fee,
120
- "min_annual_fee": min_annual_fee,
121
- "max_annual_fee": max_annual_fee,
122
- "agent_outcome": None,
123
- "messages": [],
124
- "trigger_chat": False,
125
- "trigger_compare": False,
126
- "selected_cards": [],
127
- "user_message": "",
128
- }
129
-
130
- debug_print("PIPELINE", f"Invoking LangGraph app")
131
- result = await app.ainvoke(inputs)
132
- debug_print("PIPELINE", f"LangGraph execution complete")
133
- card_lookup = result.get("card_lookup", {})
134
- for name, desc in card_lookup.items():
135
- debug_print("PIPELINE_CARD_LOOKUP", f"{name} -> Description length: {len(desc) if isinstance(desc, str) else 'N/A'}")
136
-
137
-
138
- debug_print("PIPELINE", f"Pipeline complete, returning results")
139
- return (
140
- result.get("top_card", "No top card found"),
141
- result.get("top_card_description", []),
142
- result.get("card_rows", []),
143
- result.get("card_names", []),
144
- result.get("card_lookup", {}),
145
- result.get("card_links", [])
146
- )
147
-
148
- #utility graph for chat and compare features
149
-
150
- def passthrough_node(state: CreditCardState) -> CreditCardState:
151
- return state
152
-
153
- def utility_router(state: CreditCardState):
154
- if state.get("trigger_compare", False):
155
- return "compare_node"
156
- elif state.get("trigger_chat", False):
157
- return "chat_agent"
158
- else:
159
- raise ValueError("No trigger flag set for utility graph.")
160
-
161
- def should_call_tool(state: CreditCardState):
162
- if state['router_decision'].decision == "call_tool":
163
- return "call_tool"
164
- else:
165
- return "answer_question"
166
-
167
- utility_graph = StateGraph(CreditCardState)
168
-
169
-
170
- utility_graph.add_node("router", passthrough_node)
171
- utility_graph.add_node("compare_node", compare_node_fn)
172
- utility_graph.add_node("chat_router", router_node)
173
- utility_graph.add_node("call_tool", tool_node)
174
- utility_graph.add_node("answer_question", expert_agent_node)
175
-
176
- utility_graph.set_entry_point("router")
177
-
178
- utility_graph.add_conditional_edges(
179
- "router",
180
- utility_router,
181
- {
182
- "compare_node": "compare_node",
183
- "chat_agent": "chat_router",
184
- },
185
- )
186
-
187
- utility_graph.add_conditional_edges(
188
- "chat_router",
189
- should_call_tool,
190
- {
191
- "call_tool": "call_tool",
192
- "answer_question": "answer_question",
193
- }
194
- )
195
-
196
- utility_graph.add_edge("call_tool", "answer_question")
197
- utility_graph.add_edge("answer_question", END)
198
- utility_graph.add_edge("compare_node", END)
199
-
200
- utility_app = utility_graph.compile()
 
1
+ from langgraph.graph import StateGraph, END
2
+ from data import debug_print
3
+ from nodes.agent import agent_node,TOOLS
4
+ from nodes.intent import oos_handler_node,general_info_handler_node,intent_classifier_node,CreditCardState
5
+ from nodes.format import format_output_node
6
+ from nodes.compare import compare_node_fn
7
+ from nodes.chat import router_node,tool_node,expert_agent_node
8
+ from recommender.graph_retrieval import neo4j_error_handler_node,neo4j_retrieval_node
9
+ from recommender.vectordb import query_refiner_node
10
+ from recommender.vectordb_retrieval import ranked_card_retrieval_node
11
+
12
+
13
+ # Main Graph flow
14
+ graph = StateGraph(CreditCardState)
15
+ graph.add_node("intent_classifier", intent_classifier_node)
16
+ graph.add_node("general_info_handler", general_info_handler_node)
17
+ graph.add_node("oos_handler", oos_handler_node)
18
+ graph.add_node("query_refiner", query_refiner_node)
19
+ graph.add_node("neo4j_retriever", neo4j_retrieval_node)
20
+ graph.add_node("neo4j_error_handler", neo4j_error_handler_node)
21
+ graph.add_node("ranked_card_retrieval", ranked_card_retrieval_node)
22
+ graph.add_node("agent", agent_node)
23
+ graph.add_node("format_output", format_output_node)
24
+
25
+ graph.set_entry_point("intent_classifier")
26
+
27
+ def route_after_intent_classification(state: CreditCardState):
28
+ intent = state["intent"]
29
+ debug_print("ROUTE", f"Intent classification routing with intent: '{intent}'")
30
+
31
+ if intent == "credit-card-recommendation":
32
+ return "query_refiner"
33
+ elif intent == "general-credit-related":
34
+ return "general_info_handler"
35
+ else:
36
+ return "oos_handler"
37
+
38
+ def route_after_format_output(state: CreditCardState):
39
+ if state.get("trigger_compare", False):
40
+ return "compare_node"
41
+ elif state.get("trigger_chat", False):
42
+ return "chat_node"
43
+ else:
44
+ return END
45
+
46
+ graph.add_conditional_edges(
47
+ "intent_classifier",
48
+ route_after_intent_classification,
49
+ {
50
+ "query_refiner": "query_refiner",
51
+ "general_info_handler": "general_info_handler",
52
+ "oos_handler": "oos_handler",
53
+ },
54
+ )
55
+
56
+ graph.add_edge("general_info_handler", END)
57
+ graph.add_edge("oos_handler", END)
58
+ graph.add_edge("query_refiner", "neo4j_retriever")
59
+
60
+ def route_after_neo4j_retriever(state: CreditCardState):
61
+ debug_print("ROUTE", f"neo4j_error: {state.get('neo4j_error')}")
62
+ if state.get("neo4j_error", False):
63
+ return "neo4j_error_handler"
64
+ else:
65
+ return "ranked_card_retrieval"
66
+
67
+
68
+ graph.add_conditional_edges(
69
+ "neo4j_retriever",
70
+ route_after_neo4j_retriever,
71
+ {
72
+ "neo4j_error_handler": "neo4j_error_handler",
73
+ "ranked_card_retrieval": "ranked_card_retrieval",
74
+ },
75
+ )
76
+
77
+ graph.add_edge("neo4j_error_handler", END)
78
+ graph.add_edge("ranked_card_retrieval", "agent")
79
+
80
+ graph.add_edge("agent", "format_output")
81
+ graph.add_edge("format_output",END)
82
+
83
+ app = graph.compile()
84
+
85
+ # invoking function
86
+ async def run_langgraph_pipeline(
87
+ query: str,
88
+ preferences: str,
89
+ query_intent: bool,
90
+ include_cobranded: bool,
91
+ use_eligibility: bool = False,
92
+ age=None,
93
+ income=None,
94
+ cibil=None,
95
+ min_joining_fee=None,
96
+ max_joining_fee=None,
97
+ min_annual_fee=None,
98
+ max_annual_fee=None
99
+ ):
100
+ debug_print("PIPELINE", f"Starting pipeline with query: '{query}'")
101
+ debug_print("PIPELINE", f"Preferences: '{preferences}'")
102
+ debug_print("PIPELINE", f"Query intent: {query_intent}, Include cobranded: {include_cobranded}")
103
+ debug_print("PIPELINE", f"Eligibility: {use_eligibility}, Age: {age}, Income: {income}, CIBIL: {cibil}")
104
+ debug_print("PIPELINE", f"Join fee: {min_joining_fee}-{max_joining_fee}, Annual fee: {min_annual_fee}-{max_annual_fee}")
105
+
106
+ inputs = {
107
+ "query": query,
108
+ "preferences": preferences,
109
+ "query_intent": query_intent,
110
+ "include_cobranded": include_cobranded,
111
+ "use_eligibility": use_eligibility,
112
+ "age": age,
113
+ "income": income,
114
+ "cibil": cibil,
115
+ "min_joining_fee": min_joining_fee,
116
+ "max_joining_fee": max_joining_fee,
117
+ "min_annual_fee": min_annual_fee,
118
+ "max_annual_fee": max_annual_fee,
119
+ "agent_outcome": None,
120
+ "messages": [],
121
+ "trigger_chat": False,
122
+ "trigger_compare": False,
123
+ "selected_cards": [],
124
+ "user_message": "",
125
+ }
126
+
127
+ debug_print("PIPELINE", f"Invoking LangGraph app")
128
+ result = await app.ainvoke(inputs)
129
+ debug_print("PIPELINE", f"LangGraph execution complete")
130
+ card_lookup = result.get("card_lookup", {})
131
+ for name, desc in card_lookup.items():
132
+ debug_print("PIPELINE_CARD_LOOKUP", f"{name} -> Description length: {len(desc) if isinstance(desc, str) else 'N/A'}")
133
+
134
+
135
+ debug_print("PIPELINE", f"Pipeline complete, returning results")
136
+ return (
137
+ result.get("top_card", "No top card found"),
138
+ result.get("top_card_description", []),
139
+ result.get("card_rows", []),
140
+ result.get("card_names", []),
141
+ result.get("card_lookup", {}),
142
+ result.get("card_links", [])
143
+ )
144
+
145
+ #utility graph for chat and compare features
146
+
147
+ def passthrough_node(state: CreditCardState) -> CreditCardState:
148
+ return state
149
+
150
+ def utility_router(state: CreditCardState):
151
+ if state.get("trigger_compare", False):
152
+ return "compare_node"
153
+ elif state.get("trigger_chat", False):
154
+ return "chat_agent"
155
+ else:
156
+ raise ValueError("No trigger flag set for utility graph.")
157
+
158
+ def should_call_tool(state: CreditCardState):
159
+ if state['router_decision'].decision == "call_tool":
160
+ return "call_tool"
161
+ else:
162
+ return "answer_question"
163
+
164
+ utility_graph = StateGraph(CreditCardState)
165
+
166
+
167
+ utility_graph.add_node("router", passthrough_node)
168
+ utility_graph.add_node("compare_node", compare_node_fn)
169
+ utility_graph.add_node("chat_router", router_node)
170
+ utility_graph.add_node("call_tool", tool_node)
171
+ utility_graph.add_node("answer_question", expert_agent_node)
172
+
173
+ utility_graph.set_entry_point("router")
174
+
175
+ utility_graph.add_conditional_edges(
176
+ "router",
177
+ utility_router,
178
+ {
179
+ "compare_node": "compare_node",
180
+ "chat_agent": "chat_router",
181
+ },
182
+ )
183
+
184
+ utility_graph.add_conditional_edges(
185
+ "chat_router",
186
+ should_call_tool,
187
+ {
188
+ "call_tool": "call_tool",
189
+ "answer_question": "answer_question",
190
+ }
191
+ )
192
+
193
+ utility_graph.add_edge("call_tool", "answer_question")
194
+ utility_graph.add_edge("answer_question", END)
195
+ utility_graph.add_edge("compare_node", END)
196
+
197
+ utility_app = utility_graph.compile()