File size: 5,665 Bytes
45e9462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from data import all_card_lookup,eligibility_lookup,debug_print,llm1
from pydantic_schema import RouterResult
from nodes.intent import CreditCardState
from langchain_core.tools import tool
from typing import List
from langgraph.prebuilt import ToolNode
from langchain_core.messages import  SystemMessage, HumanMessage
import pprint

# Tool for the chat flow: looks up stored card details by card name.
@tool
def fetch_card_details(card_names: List[str]) -> str:
    """Fetches details for a list of credit cards from the database.

    Use this tool to get information about specific credit cards when the user asks about them by name.

    This is useful for comparing cards or getting details about cards not in the initial recommendation.

    Args:

        card_names: A list of exact credit card names to fetch details for.

    Returns:

        A string containing the details of the requested cards.

    """
    # NOTE: the docstring above is kept verbatim on purpose. langchain's
    # @tool decorator uses the function docstring as the tool description,
    # so its wording is runtime text rather than plain documentation.
    debug_print("TOOL", f"fetch_card_details called with cards: {card_names}")

    # Lower-cased views of both lookup tables so name matching ignores case.
    descriptions_by_key = {name.lower(): text for name, text in all_card_lookup.items()}
    eligibility_by_key = {name.lower(): text for name, text in eligibility_lookup.items()}

    def _render(requested_name):
        # One text section per requested card; a missing entry in either
        # table is replaced by a fixed placeholder sentence.
        key = requested_name.lower()
        desc = descriptions_by_key.get(key, "Description not found.")
        elig = eligibility_by_key.get(key, "No eligibility or fee information available.")
        return f"Card: {requested_name}\nDescription: {desc}\n\nEligibility & Fees:\n{elig}"

    sections = [_render(requested_name) for requested_name in card_names]

    # Every requested name yields a section, so this list is only empty when
    # no names were supplied at all.
    if sections:
        return "\n\n---\n\n".join(sections)
    return "Could not find details for the requested cards. Please check the card names."

# Tools made available to the chat flow (currently only the card lookup).
CHAT_TOOLS = [fetch_card_details]
# Prebuilt LangGraph ToolNode built over CHAT_TOOLS. It is not referenced by
# the node functions in this file (tool_node calls fetch_card_details
# directly); presumably it is imported by the graph wiring elsewhere -- confirm.
chat_tool_node = ToolNode(CHAT_TOOLS)

# Router: decides whether a tool call is required to answer the latest message.
async def router_node(state: CreditCardState) -> dict:
    """Decide, via two LLM calls, whether new card details must be fetched.

    The first call asks ``llm1`` for a free-text reasoning scratchpad that
    compares the card names in the latest message against
    ``state['card_names']``. The second call feeds that scratchpad back and
    requests a JSON answer, passing the ``RouterResult`` JSON schema through
    ``extra_body={"guided_json": ...}``.

    Args:
        state: Graph state. Reads ``state['card_names']`` and the ``content``
            of the last entry in ``state['messages']``.

    Returns:
        ``{"router_decision": <RouterResult>}`` parsed from the second reply.

    Note:
        ``RouterResult.model_validate_json`` is not wrapped in any error
        handling, so a reply that does not validate propagates as an
        exception to the caller.
    """
    debug_print("NODE", "Entering VISIBLE Chain-of-Thought router_node")

    cards_in_context = state['card_names']
    user_query = state['messages'][-1].content
    # Quote each known name and comma-separate them for the prompt sentence.
    known_cards_sentence = ", ".join([f"'{card}'" for card in cards_in_context])

    # Pass 1 prompt: ask only for written reasoning, not for a decision object.
    reasoning_prompt = f"""

        You are an expert routing agent. Your job is to analyze a user's query and write down your reasoning for whether a new credit card needs to be fetched.

        

        **Follow these steps:**

        1.  **Identify:** Find the specific credit card names mentioned in the User Query.

        2.  **Compare:** Check if those names exist in the list of Known Card Names.

        3.  **Conclude:** State your final conclusion about whether a tool is needed. 

        

        **Known Card Names :** We already have information on the following cards: {known_cards_sentence}.

        **User Query:** "{user_query}"

        

        **Your Reasoning Scratchpad:**

        """

    debug_print("CoT", "Generating reasoning scratchpad...")

    reasoning_message = HumanMessage(content=reasoning_prompt)
    reasoning_reply = await llm1.ainvoke([reasoning_message])
    scratchpad = reasoning_reply.content
    debug_print("CoT", f"Generated Scratchpad:\n---SCRATCHPAD---\n{scratchpad}\n----------------")

    # Pass 2 prompt: turn the scratchpad into the structured decision.
    decision_prompt = f"""

        Based on the following reasoning scratchpad, provide your final decision in the required JSON format.

        **Required Format**

        1.  `decision`: This must be either "call_tool" or "answer_from_context".

        2.  `card_names_to_fetch`: If the decision is "call_tool", this must be a list of the new card names found in the query. Otherwise, it should be null.

                

        **Reasoning Scratchpad:**

        {scratchpad}

        

        **Final JSON Output:**

        """

    decision_schema = RouterResult.model_json_schema()

    debug_print("CoT", "Generating final decision from scratchpad...")
    # NOTE(review): "guided_json" looks like a backend-specific structured
    # output option forwarded in the request body -- confirm llm1 supports it.
    decision_message = HumanMessage(content=decision_prompt)
    decision_reply = await llm1.ainvoke(
        [decision_message],
        extra_body={"guided_json": decision_schema},
    )

    parsed_decision = RouterResult.model_validate_json(decision_reply.content)
    debug_print("ROUTER", f"Final Decision: {parsed_decision}")

    return {"router_decision": parsed_decision}

# Tool step: runs the card lookup for the names chosen by the router.
def tool_node(state: CreditCardState) -> dict:
    """Fetch details for the cards the router asked for.

    Args:
        state: Graph state. Reads ``state['router_decision'].card_names_to_fetch``.

    Returns:
        ``{"new_card_info": <text>}`` with the output of
        ``fetch_card_details`` for the requested names, or
        ``{"new_card_info": None}`` when the router listed no names
        (``None`` or an empty list).
    """
    debug_print("NODE", "Entering tool_node")
    requested_cards = state['router_decision'].card_names_to_fetch
    # Only call the tool when there is at least one name to look up.
    fetched_text = (
        fetch_card_details.invoke({"card_names": requested_cards})
        if requested_cards
        else None
    )
    return {"new_card_info": fetched_text}

# Expert agent: merges any fetched card info into the system prompt and generates the final reply.
async def expert_agent_node(state: CreditCardState) -> dict:
    """Generate the assistant reply, adding any newly fetched card info.

    The ``content`` of ``state['messages'][0]`` is used as the system prompt.
    When ``state`` holds a truthy ``"new_card_info"``, that text is appended
    to the system prompt inside ``<Newly_Fetched_Information>`` tags. The
    remaining messages are sent unchanged after the system message.

    The full state and the final message list are pretty-printed to stdout
    on every call.

    Args:
        state: Graph state. Reads ``state['messages']`` and the optional
            ``"new_card_info"`` key.

    Returns:
        ``{"messages": [response]}`` where ``response`` is the result of
        ``llm1.ainvoke``.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print(f"\n{heavy_rule}")
    debug_print("EXPERT_AGENT_ENTRY", "Full state entering expert_agent_node:")
    pprint.pprint(state, indent=2)
    print(f"{heavy_rule}\n")

    conversation = state['messages']
    base_system_prompt = conversation[0].content
    fetched_info = state.get("new_card_info")
    # Everything after the first message is forwarded as chat history.
    history = conversation[1:]

    # Start from the original system prompt; extend it only if the tool step
    # produced something.
    final_system_prompt = base_system_prompt
    if fetched_info:
        final_system_prompt = (
            base_system_prompt
            + "\n\n<Newly_Fetched_Information>\n"
            + fetched_info
            + "\n</Newly_Fetched_Information>"
        )

    outgoing = [SystemMessage(content=final_system_prompt)] + history

    print(f"\n{light_rule}")
    debug_print("EXPERT_AGENT_PROMPT", "Final prompt being sent to LLM:")
    pprint.pprint(outgoing)
    print(f"{light_rule}\n")

    reply = await llm1.ainvoke(outgoing)
    return {"messages": [reply]}