|
|
from langgraph.graph import StateGraph, END
|
|
|
from typing import Literal
|
|
|
from data import debug_print
|
|
|
from langchain_core.messages import AIMessage
|
|
|
from langgraph.prebuilt import ToolNode,tools_condition
|
|
|
from nodes.agent import agent_node,TOOLS
|
|
|
from nodes.intent import oos_handler_node,general_info_handler_node,intent_classifier_node,CreditCardState
|
|
|
from nodes.format import format_output_node
|
|
|
from nodes.compare import compare_node_fn
|
|
|
from nodes.chat import router_node,tool_node,expert_agent_node
|
|
|
from recommender.graph_retrieval import neo4j_error_handler_node,neo4j_retrieval_node
|
|
|
from recommender.vectordb import query_refiner_node
|
|
|
from recommender.vectordb_retrieval import ranked_card_retrieval_node
|
|
|
|
|
|
|
|
|
|
|
|
# Main recommendation graph. Every node below shares the CreditCardState schema.
graph = StateGraph(CreditCardState)

# Register nodes in a fixed order; the loop temporaries are deleted afterwards
# so they do not linger in the module namespace.
for _node_name, _node_fn in (
    ("intent_classifier", intent_classifier_node),
    ("general_info_handler", general_info_handler_node),
    ("oos_handler", oos_handler_node),
    ("query_refiner", query_refiner_node),
    ("neo4j_retriever", neo4j_retrieval_node),
    ("neo4j_error_handler", neo4j_error_handler_node),
    ("ranked_card_retrieval", ranked_card_retrieval_node),
    ("agent", agent_node),
    ("format_output", format_output_node),
):
    graph.add_node(_node_name, _node_fn)
del _node_name, _node_fn

# Every run starts by classifying the user's intent.
graph.set_entry_point("intent_classifier")
|
|
|
|
|
|
def route_after_intent_classification(
    state: CreditCardState,
) -> Literal["query_refiner", "general_info_handler", "oos_handler"]:
    """Pick the branch to run after ``intent_classifier``.

    Args:
        state: Graph state; only its ``"intent"`` entry is read.

    Returns:
        ``"query_refiner"`` for ``"credit-card-recommendation"``,
        ``"general_info_handler"`` for ``"general-credit-related"``, and
        ``"oos_handler"`` for any other value, including a missing intent.
    """
    # Use .get (like the other routers in this file) so a missing intent falls
    # through to the out-of-scope branch instead of aborting with KeyError.
    intent = state.get("intent")

    debug_print("ROUTE", f"Intent classification routing with intent: '{intent}'")

    if intent == "credit-card-recommendation":
        return "query_refiner"
    if intent == "general-credit-related":
        return "general_info_handler"
    return "oos_handler"
|
|
|
|
|
|
def route_after_format_output(state: CreditCardState):
    """Choose a follow-up node after ``format_output`` from the trigger flags.

    ``trigger_compare`` is checked before ``trigger_chat``. When neither flag is
    truthy, the graph's END sentinel is returned.
    """
    # NOTE(review): this router is not used anywhere in this file. The main graph
    # wires "format_output" straight to END, and "chat_node" is not a registered
    # node in either graph (the utility graph uses "chat_agent" -> "chat_router").
    # Confirm whether this is still needed before wiring it up.
    for flag, target in (
        ("trigger_compare", "compare_node"),
        ("trigger_chat", "chat_node"),
    ):
        if state.get(flag, False):
            return target
    return END
|
|
|
|
|
|
# After classification, jump to whichever node the router names. Each route key
# is also the node name, so the path map is an identity mapping.
graph.add_conditional_edges(
    "intent_classifier",
    route_after_intent_classification,
    {
        target: target
        for target in ("query_refiner", "general_info_handler", "oos_handler")
    },
)

# Direct-answer branches finish the run immediately.
graph.add_edge("general_info_handler", END)
graph.add_edge("oos_handler", END)

# Recommendation branch: the refined query feeds the Neo4j retrieval step.
graph.add_edge("query_refiner", "neo4j_retriever")
|
|
|
|
|
|
def route_after_neo4j_retriever(state: CreditCardState):
    """Choose the node to run after the Neo4j retrieval step.

    Returns ``"neo4j_error_handler"`` when the state's ``neo4j_error`` entry is
    truthy, and ``"ranked_card_retrieval"`` when it is falsy or missing.
    """
    failure = state.get("neo4j_error")
    debug_print("ROUTE", f"neo4j_error: {failure}")
    return "neo4j_error_handler" if failure else "ranked_card_retrieval"
|
|
|
|
|
|
|
|
|
# After retrieval, either handle the Neo4j failure or continue to ranking.
# Route keys double as node names.
graph.add_conditional_edges(
    "neo4j_retriever",
    route_after_neo4j_retriever,
    {
        target: target
        for target in ("neo4j_error_handler", "ranked_card_retrieval")
    },
)

# The error branch ends the run.
graph.add_edge("neo4j_error_handler", END)

# Happy path: ranked retrieval -> agent -> output formatting -> end.
graph.add_edge("ranked_card_retrieval", "agent")
graph.add_edge("agent", "format_output")
graph.add_edge("format_output", END)

app = graph.compile()
|
|
|
|
|
|
|
|
|
async def run_langgraph_pipeline(
    query: str,
    preferences: str,
    query_intent: bool,
    include_cobranded: bool,
    use_eligibility: bool = False,
    age=None,
    income=None,
    cibil=None,
    min_joining_fee=None,
    max_joining_fee=None,
    min_annual_fee=None,
    max_annual_fee=None,
):
    """Run the main recommendation graph (``app``) once and unpack its results.

    Every argument is copied verbatim into the initial graph state. The
    remaining fields (``agent_outcome``, ``messages``, ``trigger_chat``,
    ``trigger_compare``, ``selected_cards``, ``user_message``) start empty.

    Args:
        query: The user's free-text request.
        preferences: Free-text preferences, passed through unchanged.
        query_intent: Passed through as ``state["query_intent"]``.
        include_cobranded: Passed through as ``state["include_cobranded"]``.
        use_eligibility: Passed through as ``state["use_eligibility"]``.
        age, income, cibil: Optional eligibility inputs (``None`` when unset).
        min_joining_fee, max_joining_fee, min_annual_fee, max_annual_fee:
            Optional fee bounds (``None`` when unset).

    Returns:
        A 6-tuple ``(top_card, top_card_description, card_rows, card_names,
        card_lookup, card_links)`` read from the final state. A key may be
        absent, e.g. when the run ends at the general-info, out-of-scope or
        Neo4j-error branch. Missing keys fall back to ``"No top card found"``,
        ``[]``, ``[]``, ``[]``, ``{}`` and ``[]`` respectively.
    """
    debug_print("PIPELINE", f"Starting pipeline with query: '{query}'")
    debug_print("PIPELINE", f"Preferences: '{preferences}'")
    debug_print("PIPELINE", f"Query intent: {query_intent}, Include cobranded: {include_cobranded}")
    debug_print("PIPELINE", f"Eligibility: {use_eligibility}, Age: {age}, Income: {income}, CIBIL: {cibil}")
    debug_print("PIPELINE", f"Join fee: {min_joining_fee}-{max_joining_fee}, Annual fee: {min_annual_fee}-{max_annual_fee}")

    inputs = {
        "query": query,
        "preferences": preferences,
        "query_intent": query_intent,
        "include_cobranded": include_cobranded,
        "use_eligibility": use_eligibility,
        "age": age,
        "income": income,
        "cibil": cibil,
        "min_joining_fee": min_joining_fee,
        "max_joining_fee": max_joining_fee,
        "min_annual_fee": min_annual_fee,
        "max_annual_fee": max_annual_fee,
        "agent_outcome": None,
        "messages": [],
        "trigger_chat": False,
        "trigger_compare": False,
        "selected_cards": [],
        "user_message": "",
    }

    debug_print("PIPELINE", "Invoking LangGraph app")
    result = await app.ainvoke(inputs)
    debug_print("PIPELINE", "LangGraph execution complete")

    # If the final state holds None (or another falsy value) under
    # "card_lookup", normalise it to {}. Otherwise the debug-only loop below
    # would raise AttributeError, and callers always receive a dict.
    card_lookup = result.get("card_lookup") or {}
    for name, desc in card_lookup.items():
        debug_print(
            "PIPELINE_CARD_LOOKUP",
            f"{name} -> Description length: {len(desc) if isinstance(desc, str) else 'N/A'}",
        )

    debug_print("PIPELINE", "Pipeline complete, returning results")

    return (
        result.get("top_card", "No top card found"),
        result.get("top_card_description", []),
        result.get("card_rows", []),
        result.get("card_names", []),
        card_lookup,
        result.get("card_links", []),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def passthrough_node(state: CreditCardState) -> CreditCardState:
    """Identity node: hand back the incoming state object unchanged.

    Used as the utility graph's entry node so that the routing decision can be
    made by the conditional edge that follows it.
    """
    # NOTE(review): returning the whole state re-submits every key as an update.
    # If CreditCardState declares reducer-backed keys (e.g. an operator.add list),
    # those could be duplicated; returning {} would be a pure no-op. Confirm
    # against the CreditCardState definition.
    return state
|
|
|
|
|
|
def utility_router(state: CreditCardState):
    """Pick the utility-graph branch from the trigger flags.

    ``trigger_compare`` takes priority over ``trigger_chat``.

    Returns:
        ``"compare_node"`` or ``"chat_agent"``.

    Raises:
        ValueError: if neither flag is truthy.
    """
    if state.get("trigger_compare", False):
        return "compare_node"
    if state.get("trigger_chat", False):
        return "chat_agent"
    raise ValueError("No trigger flag set for utility graph.")
|
|
|
|
|
|
def should_call_tool(state: CreditCardState):
    """Route the chat branch based on the router's decision.

    Reads ``state["router_decision"].decision``. Returns ``"call_tool"`` when it
    equals ``"call_tool"`` and ``"answer_question"`` for any other value.
    """
    wants_tool = state['router_decision'].decision == "call_tool"
    return "call_tool" if wants_tool else "answer_question"
|
|
|
|
|
|
# Secondary graph. From the "router" entry node, utility_router sends the run
# either to "compare_node" or into the chat loop
# ("chat_router" -> optional "call_tool" -> "answer_question").
utility_graph = StateGraph(CreditCardState)

# Register nodes in a fixed order; the loop temporaries are deleted afterwards
# so they do not linger in the module namespace.
for _node_name, _node_fn in (
    ("router", passthrough_node),
    ("compare_node", compare_node_fn),
    ("chat_router", router_node),
    ("call_tool", tool_node),
    ("answer_question", expert_agent_node),
):
    utility_graph.add_node(_node_name, _node_fn)
del _node_name, _node_fn

utility_graph.set_entry_point("router")

# Route key "chat_agent" maps onto the "chat_router" node.
utility_graph.add_conditional_edges(
    "router",
    utility_router,
    {"compare_node": "compare_node", "chat_agent": "chat_router"},
)

utility_graph.add_conditional_edges(
    "chat_router",
    should_call_tool,
    {"call_tool": "call_tool", "answer_question": "answer_question"},
)

# A tool call always feeds back into answering; both leaf branches then end.
utility_graph.add_edge("call_tool", "answer_question")
utility_graph.add_edge("answer_question", END)
utility_graph.add_edge("compare_node", END)

utility_app = utility_graph.compile()
|
|
|
|