File size: 5,972 Bytes
0adc292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from typing import TypedDict, List
from langchain.memory import ConversationBufferMemory
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser,JsonOutputToolsParser
from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
import os
load_dotenv()

# Gemini credential, looked up after load_dotenv() has run; None when unset.
api_key = os.environ.get('GEMINI_API_KEY')
# Conversation buffer configured with memory_key="chat_history" and
# return_messages=True.
memory = ConversationBufferMemory(
    return_messages=True,
    memory_key="chat_history",
)
# ---------- Step 1: Define State ----------
class AgentState(TypedDict):
    """State dictionary handed from node to node in the agent graph.

    Every node in this file receives this mapping, writes the keys it is
    responsible for, and returns it.
    """

    query: str  # user query; refine_query() replaces it with the rewritten text
    banks: List[str]  # bank names filled in by extract_entities()
    category: str  # category label for the query (see also `cat`)
    # Key that classify() writes and the router in build_agent() reads.
    # Declared here so the label is part of the graph's state schema.
    cat: str
    product: str  # financial product filled in by extract_entities()
    chat_history: List[str]  # "User:..." / "Assistant:..." transcript lines
    search_results: str  # stringified search output stored by search_info()
    risk_analysis: str  # not written by the nodes shown in this file
    final_answer: str  # model answer stored by summarize()


# ---------- Step 2: Define schema for parsing ----------
class QueryEntities(BaseModel):
    # Structured result of entity extraction: the module-level `parser` is
    # built on this model and extract_entities() reads `banks` and `product`
    # from the parsed object.
    # NOTE: documented with comments instead of a docstring on purpose; a
    # class docstring would be included in the model's JSON schema and so
    # would change the format instructions sent to the LLM.
    banks: List[str] = Field(
        description="List of banks mentioned",
    )
    product: str = Field(
        description="Financial product (like FD, loan, etc.)",
    )

class RefinedQuery(BaseModel):
    # Structured output of the query-rewriting step; refine_query() parses
    # the model reply into this shape and reads `refined_query`.
    # NOTE: documented with comments instead of a docstring on purpose; a
    # class docstring would be included in the model's JSON schema and so
    # would change the format instructions sent to the LLM.
    refined_query: str = Field(..., description="The improved and precise query")
    clarification_needed: bool = Field(..., description="True if query is ambiguous and needs clarification")
    # Only meaningful when clarification_needed is True, so it defaults to an
    # empty string: a reply that omits it for an unambiguous query still
    # parses instead of failing validation.
    clarification_question: str = Field("", description="The question to ask user if clarification_needed is True")



# ---------- Step 3: LLMs ----------
# Single Gemini chat model shared by the LLM-backed nodes in this file.
# (Previously written as `llm = llm=...`, an accidental chained assignment.)
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.5, api_key=api_key)

# Parser that turns the entity-extraction reply into a QueryEntities object.
parser = PydanticOutputParser(pydantic_object=QueryEntities)

# Prompt for extract_entities(); the parser's format instructions are bound
# once here so callers only have to supply `query`.
entity_prompt = ChatPromptTemplate.from_messages([
    ("system", "Extract the banks and product from the query."),
    ("user", "{query}\n{format_instructions}")
]).partial(format_instructions=parser.get_format_instructions())

# Prompt for classify(); asks the model to answer 'Bank' or 'Non-Bank'.
classify_prompt = ChatPromptTemplate.from_messages([
    ("system","Classify the given query as Banking or Non-Banking.Strictly and if the query is about banking related reply 'Bank' Else 'Non-Bank' "),
    ("user","{query}")
])


# ---------- Step 4: Define nodes ----------
def extract_entities(state: AgentState):
    """Fill ``banks`` and ``product`` in the state from the user's query.

    Formats ``entity_prompt`` with ``state["query"]``, sends it to the
    shared ``llm``, and parses the reply text with ``parser``.

    Returns the same ``state`` mapping after updating it in place.
    """
    prompt_value = entity_prompt.format_prompt(query=state["query"])
    reply = llm.invoke(prompt_value)
    extracted = parser.parse(reply.content)

    state["banks"] = extracted.banks
    state["product"] = extracted.product
    return state

def _normalize_category(raw):
    """Map a free-form model reply onto exactly 'Bank' or 'Non-Bank'."""
    cleaned = str(raw).strip().strip("'\"`*.!").strip().casefold()
    if cleaned.startswith("bank"):
        return "Bank"
    # 'Non-Bank', 'non-banking', and anything unrecognised all end up here.
    return "Non-Bank"


def classify(state:AgentState):
    """Label the query as 'Bank' or 'Non-Bank' and store the label in the state.

    Formats ``classify_prompt`` with ``state['query']`` and sends it to the
    shared ``llm``. The reply is normalized before it is stored, because the
    router in ``build_agent`` only maps the exact strings 'Bank' and
    'Non-Bank'; a reply such as "bank" or "Bank." would otherwise have no
    matching route. Replies that cannot be recognised are treated as
    'Non-Bank'.

    Returns the same ``state`` mapping after updating it in place.
    """
    messages = classify_prompt.format_messages(query=state['query'])
    reply = llm.invoke(messages)
    label = _normalize_category(reply.content)

    # 'cat' is the key the router reads; 'category' is the field that
    # AgentState declares for the same information, so fill both.
    state['cat'] = label
    state['category'] = label
    return state

# Tavily search tool invoked by search_info(); constructed with max_results=10.
tavily = TavilySearchResults(max_results=10)

def update_history(state:AgentState):
    """Record the current query and any existing answer in the chat history.

    Appends ``"User:<query>"`` when the state has a non-empty ``query`` and
    ``"Assistant:<final_answer>"`` when it has a non-empty ``final_answer``,
    then stores the result under ``chat_history``.

    The history is built on a copy, so the list object the caller passed in
    is not modified, and a missing or ``None`` ``chat_history`` is treated
    as empty.

    Returns the same ``state`` mapping with ``chat_history`` replaced.
    """
    # `or []` covers both a missing key and an explicit None; list() copies.
    history = list(state.get("chat_history") or [])

    if state.get("query"):
        history.append(f"User:{state['query']}")

    if state.get("final_answer"):
        history.append(f"Assistant:{state['final_answer']}")

    state['chat_history'] = history
    return state


def search_info(state: AgentState):
    """Run a Tavily search for the current query and keep the raw output.

    Invokes the module-level ``tavily`` tool with ``state['query']`` and
    stores ``str()`` of whatever it returns under ``search_results``.

    Returns the same ``state`` mapping after updating it in place.
    """
    hits = tavily.invoke({"query": state['query']})
    state["search_results"] = str(hits)
    return state

# Parser for the RefinedQuery schema; refine_query() uses it both for the
# format instructions embedded in its prompt and to parse the model reply.
parser1=PydanticOutputParser(pydantic_object=RefinedQuery)

def refine_query(state: AgentState):
    """Rewrite the user's query using the chat history as context.

    Builds a prompt from the joined ``chat_history`` lines, the current
    ``query``, and the format instructions of ``parser1``, sends it to the
    shared ``llm``, and parses the reply into a ``RefinedQuery``.

    ``state["query"]`` is replaced with the refined text only when that text
    is non-blank, so an empty rewrite never wipes out the user's query. A
    missing or ``None`` ``chat_history`` is treated as empty.

    Returns the same ``state`` mapping after updating it in place.
    Errors raised by ``parser1.parse`` on an unparseable reply propagate.
    """
    history_text = "\n".join(state.get("chat_history") or [])
    prompt = f"""You are a query rewriting agent.
    Use the conversations context and latest query to produce more precise and umabiguous query. Ask for clarifying questions if the 
    query seems ambiguous.
    Rules:
    Always consider the history of the chat as context .
    Do not lose context from previous turns 
     COnversations so far 
     {history_text}
    latest_query:
    {state['query']}

    {parser1.get_format_instructions()}
"""
    refined = llm.invoke(prompt)
    parsed = parser1.parse(refined.content)
    # NOTE(review): clarification_needed / clarification_question are parsed
    # but not acted on here; only the rewritten query is used.
    if parsed.refined_query and parsed.refined_query.strip():
        state["query"] = parsed.refined_query
    return state

def summarize(state: AgentState):
    """Produce the final answer from the query, history, and search results.

    Builds a prompt containing the chat history, ``state['query']`` and
    ``state['search_results']``, sends it to the shared ``llm``, stores the
    reply text under ``final_answer``, and appends a ``"User: ..."`` and an
    ``"Assistant: ..."`` line to ``chat_history``.

    The history list is resolved once before the prompt is built, so a
    missing or ``None`` ``chat_history`` yields an empty history instead of
    failing while formatting the prompt.

    Returns the same ``state`` mapping after updating it in place.
    """
    # Resolve (and, if needed, create) the history list up front; the prompt
    # and the appends below then share the same list object.
    history = state.get("chat_history")
    if history is None:
        history = state["chat_history"] = []

    prompt = f"""You are an financial assitant which will help user with their banking queries.
    1. use chat_history  and query as context .
    2.Think step by step about the user query  
    -Main products discussed in the query 
    -what are the key features and benefits 
    -Are there any hidden charges or risk involved 
    3. Be objective and unbiased 
    Step 4: Create the final answer in this structure:
    - **Summary (2–3 sentences):** High-level explanation of the product.  
- **Key Features:** Bullet points of the most important details.  
- **Risks/Limitations:** Bullet points of what the user should watch out for.  
- **Actionable Insight:** One suggestion or next step the user could take.

    chat_history:{history}
    Query: {state['query']}
    Results: {state['search_results']}"""
    answer = llm.invoke(prompt)
    state["final_answer"] = answer.content
    history.append(f"User: {state['query']}")
    history.append(f"Assistant: {answer.content}")
    return state


# ---------- Step 5: Build Graph ----------
def build_agent(state:AgentState):
    """Assemble and compile the agent graph.

    The graph starts at ``classify_query``. A 'Bank' label continues through
    ``add_history`` -> ``query_refinement`` -> ``search_info`` ->
    ``summarize`` and then ends; a 'Non-Bank' label ends immediately.

    The ``state`` argument is accepted but not used while building the graph.

    Returns the compiled graph object.
    """
    graph = StateGraph(AgentState)

    # Register the nodes in a fixed order: (graph name, node function).
    node_table = (
        ('classify_query', classify),
        ("add_history", update_history),
        ("search_info", search_info),
        ("query_refinement", refine_query),
        ("summarize", summarize),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("classify_query")

    # Route on the label that classify() stored under 'cat'.
    routes = {
        "Bank": "add_history",
        "Non-Bank": END,
    }
    graph.add_conditional_edges(
        "classify_query",
        lambda current: current['cat'],
        routes,
    )

    # Linear pipeline for banking queries.
    edge_table = (
        ("add_history", "query_refinement"),
        ("query_refinement", "search_info"),
        ("search_info", "summarize"),
        ("summarize", END),
    )
    for source, target in edge_table:
        graph.add_edge(source, target)

    return graph.compile()