# Banking_Research_agent / workflow.py
# Author: DC0101 — last change: "Update workflow.py" (commit 0adc292, verified)
from typing import TypedDict, List
from langchain.memory import ConversationBufferMemory
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser,JsonOutputToolsParser
from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
import os
load_dotenv()
api_key=os.getenv('GEMINI_API_KEY')
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# ---------- Step 1: Define State ----------
class AgentState(TypedDict):
query: str
banks: List[str]
category:str
product: str
chat_history:List[str]
search_results: str
risk_analysis:str
final_answer: str
# ---------- Step 2: Define schema for parsing ----------
class QueryEntities(BaseModel):
banks: List[str] = Field(description="List of banks mentioned")
product: str = Field(description="Financial product (like FD, loan, etc.)")
class RefinedQuery(BaseModel):
refined_query: str = Field(..., description="The improved and precise query")
clarification_needed: bool = Field(..., description="True if query is ambiguous and needs clarification")
clarification_question: str = Field(..., description="The question to ask user if clarification_needed is True")
# ---------- Step 3: LLMs ----------
llm = llm=ChatGoogleGenerativeAI(model="gemini-1.5-flash",temperature=0.5,api_key=api_key)
# Parser
parser = PydanticOutputParser(pydantic_object=QueryEntities)
entity_prompt = ChatPromptTemplate.from_messages([
("system", "Extract the banks and product from the query."),
("user", "{query}\n{format_instructions}")
]).partial(format_instructions=parser.get_format_instructions())
classify_prompt=ChatPromptTemplate.from_messages([
("system","Classify the given query as Banking or Non-Banking.Strictly and if the query is about banking related reply 'Bank' Else 'Non-Bank' "),
("user","{query}")
])
# ---------- Step 4: Define nodes ----------
def extract_entities(state: AgentState):
parsed = llm.invoke(entity_prompt.format_prompt(query=state["query"]))
entities = parser.parse(parsed.content)
state["banks"] = entities.banks
state["product"] = entities.product
return state
def classify(state:AgentState):
query=state['query']
msg=classify_prompt.format_messages(query=query)
cat=llm.invoke(msg)
state['cat']=cat.content.strip()
return state
tavily = TavilySearchResults(max_results=10)
def update_history(state:AgentState):
history=state.get("chat_history",[])
if state.get("query"):
history.append(f"User:{state['query']}")
if state.get("final_answer"):
history.append(f"Assistant:{state['final_answer']}")
state['chat_history']=history
return state
def search_info(state: AgentState):
search_query = state['query']
results = tavily.invoke({"query": search_query})
state["search_results"] = str(results)
return state
parser1=PydanticOutputParser(pydantic_object=RefinedQuery)
def refine_query(state: AgentState):
history_text = "\n".join(state.get("chat_history", []))
prompt = f"""You are a query rewriting agent.
Use the conversations context and latest query to produce more precise and umabiguous query. Ask for clarifying questions if the
query seems ambiguous.
Rules:
Always consider the history of the chat as context .
Do not lose context from previous turns
COnversations so far
{history_text}
latest_query:
{state['query']}
{parser1.get_format_instructions()}
"""
refined = llm.invoke(prompt)
parsed=parser1.parse(refined.content)
state["query"] = parsed.refined_query
return state
def summarize(state: AgentState):
prompt = f"""You are an financial assitant which will help user with their banking queries.
1. use chat_history and query as context .
2.Think step by step about the user query
-Main products discussed in the query
-what are the key features and benefits
-Are there any hidden charges or risk involved
3. Be objective and unbiased
Step 4: Create the final answer in this structure:
- **Summary (2–3 sentences):** High-level explanation of the product.
- **Key Features:** Bullet points of the most important details.
- **Risks/Limitations:** Bullet points of what the user should watch out for.
- **Actionable Insight:** One suggestion or next step the user could take.
chat_history:{state["chat_history"]}
Query: {state['query']}
Results: {state['search_results']}"""
answer = llm.invoke(prompt)
state["final_answer"] = answer.content
state.setdefault("chat_history", []).append(f"User: {state['query']}")
state["chat_history"].append(f"Assistant: {answer.content}")
return state
# ---------- Step 5: Build Graph ----------
def build_agent(state:AgentState):
workflow = StateGraph(AgentState)
workflow.add_node('classify_query',classify)
workflow.add_node("add_history",update_history)
workflow.add_node("search_info", search_info)
workflow.add_node("query_refinement",refine_query)
workflow.add_node("summarize", summarize)
workflow.set_entry_point("classify_query")
workflow.add_conditional_edges("classify_query",
lambda state:state['cat'],
{
"Bank":"add_history",
"Non-Bank":END
})
workflow.add_edge("add_history","query_refinement")
workflow.add_edge("query_refinement", "search_info")
workflow.add_edge("search_info", "summarize")
workflow.add_edge("summarize", END)
app = workflow.compile()
return app