# Banking Q&A agent: LangGraph pipeline (classify -> history -> refine -> search -> summarize).
| from typing import TypedDict, List | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain.output_parsers import PydanticOutputParser,JsonOutputToolsParser | |
| from pydantic import BaseModel, Field | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from dotenv import load_dotenv | |
| from langchain_core.prompts import PromptTemplate | |
| import os | |
# Load environment variables (expects GEMINI_API_KEY in a .env file).
load_dotenv()
api_key=os.getenv('GEMINI_API_KEY')
# NOTE(review): this buffer memory is never read by the graph below — chat
# history is tracked in AgentState['chat_history'] instead. Confirm whether
# it can be removed.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# ---------- Step 1: Define State ----------
class AgentState(TypedDict):
    """Shared state passed between the LangGraph nodes.

    Keys are written incrementally as the graph runs:
    classify -> add_history -> query_refinement -> search_info -> summarize.
    """
    query: str               # latest user query (rewritten in place by refine_query)
    banks: List[str]         # banks extracted from the query
    category: str            # declared home for the Bank/Non-Bank label
    # 'cat' is the key classify() actually writes and the conditional edge in
    # build_agent() reads; it must be declared here so the state schema
    # accepts the update instead of dropping it.
    cat: str
    product: str             # financial product mentioned (FD, loan, etc.)
    chat_history: List[str]  # running "User:..." / "Assistant:..." transcript
    search_results: str      # stringified Tavily search results
    risk_analysis: str       # reserved; not populated by any node visible here
    final_answer: str        # summarizer output returned to the user
# ---------- Step 2: Define schema for parsing ----------
class QueryEntities(BaseModel):
    """Structured-output schema for entity extraction from a user query."""
    banks: List[str] = Field(description="List of banks mentioned")
    product: str = Field(description="Financial product (like FD, loan, etc.)")
class RefinedQuery(BaseModel):
    """Structured-output schema for the query-rewriting step."""
    refined_query: str = Field(..., description="The improved and precise query")
    clarification_needed: bool = Field(..., description="True if query is ambiguous and needs clarification")
    clarification_question: str = Field(..., description="The question to ask user if clarification_needed is True")
# ---------- Step 3: LLMs ----------
# Fixed: original read `llm = llm=ChatGoogleGenerativeAI(...)` — a redundant
# double assignment to the same name.
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.5, api_key=api_key)

# Parser that turns the model's JSON reply into a QueryEntities instance.
parser = PydanticOutputParser(pydantic_object=QueryEntities)

# Entity-extraction prompt; the parser's format instructions are baked in once.
entity_prompt = ChatPromptTemplate.from_messages([
    ("system", "Extract the banks and product from the query."),
    ("user", "{query}\n{format_instructions}")
]).partial(format_instructions=parser.get_format_instructions())

# Routing prompt: the model is told to answer exactly 'Bank' or 'Non-Bank',
# which the conditional edge in build_agent() matches on.
classify_prompt = ChatPromptTemplate.from_messages([
    ("system", "Classify the given query as Banking or Non-Banking.Strictly and if the query is about banking related reply 'Bank' Else 'Non-Bank' "),
    ("user", "{query}")
])
| # ---------- Step 4: Define nodes ---------- | |
def extract_entities(state: "AgentState"):
    """Populate state['banks'] and state['product'] from the raw query.

    NOTE(review): this node is defined but never added to the graph in
    build_agent() — confirm whether it should be wired in.
    """
    raw = llm.invoke(entity_prompt.format_prompt(query=state["query"]))
    extracted = parser.parse(raw.content)
    state["banks"], state["product"] = extracted.banks, extracted.product
    return state
def classify(state: AgentState):
    """Label the query 'Bank' or 'Non-Bank' for routing.

    Writes the label under both 'category' (the key declared in AgentState)
    and 'cat' (the key the conditional edge in the original graph reads), so
    either spelling works downstream.
    """
    messages = classify_prompt.format_messages(query=state['query'])
    label = llm.invoke(messages).content.strip()
    state['category'] = label
    state['cat'] = label  # kept for backward compatibility with existing routing
    return state
# Web-search tool used by search_info (top 10 results per query).
tavily = TavilySearchResults(max_results=10)
def update_history(state: "AgentState"):
    """Append the current turn to state['chat_history'] and return the state.

    The user query is logged when present; a final answer (from a previous
    node run) is logged after it.
    """
    log = state.get("chat_history", [])
    query = state.get("query")
    if query:
        log.append(f"User:{query}")
    answer = state.get("final_answer")
    if answer:
        log.append(f"Assistant:{answer}")
    state["chat_history"] = log
    return state
def search_info(state: "AgentState"):
    """Run a Tavily web search for the (possibly refined) query and stash the
    raw results, stringified, in state['search_results']."""
    hits = tavily.invoke({"query": state["query"]})
    state["search_results"] = str(hits)
    return state
# Parser for the query-rewriting step's structured (RefinedQuery) output.
parser1=PydanticOutputParser(pydantic_object=RefinedQuery)
def refine_query(state: AgentState):
    """Rewrite state['query'] into a more precise query using chat history.

    NOTE(review): the parsed RefinedQuery also carries clarification_needed /
    clarification_question, but both are ignored here — the graph has no
    clarification branch. Confirm whether one should be added.
    """
    history_text = "\n".join(state.get("chat_history", []))
    # Prompt typos from the original ("umabiguous", "COnversations") fixed.
    prompt = f"""You are a query rewriting agent.
Use the conversation context and latest query to produce a more precise and unambiguous query. Ask clarifying questions if the
query seems ambiguous.
Rules:
Always consider the history of the chat as context.
Do not lose context from previous turns.
Conversations so far:
{history_text}
latest_query:
{state['query']}
{parser1.get_format_instructions()}
"""
    refined = llm.invoke(prompt)
    parsed = parser1.parse(refined.content)
    state["query"] = parsed.refined_query
    return state
def summarize(state: AgentState):
    """Produce the final structured answer from the search results.

    Stores the answer in state['final_answer'] and appends only the assistant
    turn to the transcript: the user turn is already appended by the
    add_history node (update_history), so re-appending it here — as the
    original code did — duplicated every query in chat_history.
    """
    history = state.get("chat_history", [])  # tolerate a missing key
    # Prompt typo fixed ("an financial assitant" -> "a financial assistant").
    prompt = f"""You are a financial assistant which will help user with their banking queries.
1. use chat_history and query as context .
2.Think step by step about the user query
-Main products discussed in the query
-what are the key features and benefits
-Are there any hidden charges or risk involved
3. Be objective and unbiased
Step 4: Create the final answer in this structure:
- **Summary (2–3 sentences):** High-level explanation of the product.
- **Key Features:** Bullet points of the most important details.
- **Risks/Limitations:** Bullet points of what the user should watch out for.
- **Actionable Insight:** One suggestion or next step the user could take.
chat_history:{history}
Query: {state['query']}
Results: {state['search_results']}"""
    answer = llm.invoke(prompt)
    state["final_answer"] = answer.content
    state.setdefault("chat_history", []).append(f"Assistant: {answer.content}")
    return state
# ---------- Step 5: Build Graph ----------
def build_agent(state: AgentState):
    """Compile and return the LangGraph workflow.

    Flow: classify_query -> (Bank) add_history -> query_refinement ->
    search_info -> summarize -> END; Non-Bank queries end immediately.

    NOTE(review): the `state` parameter is never used (kept for interface
    compatibility), and extract_entities is defined above but never wired
    into the graph — confirm both are intentional.
    """
    workflow = StateGraph(AgentState)
    workflow.add_node('classify_query', classify)
    workflow.add_node("add_history", update_history)
    workflow.add_node("search_info", search_info)
    workflow.add_node("query_refinement", refine_query)
    workflow.add_node("summarize", summarize)
    workflow.set_entry_point("classify_query")
    # Route on the classification label. Accept either key spelling: the
    # schema declares 'category' while the classify node historically wrote
    # 'cat', so check both instead of hard-crashing on a missing key.
    workflow.add_conditional_edges(
        "classify_query",
        lambda s: s.get('category') or s.get('cat'),
        {
            "Bank": "add_history",
            "Non-Bank": END,
        },
    )
    workflow.add_edge("add_history", "query_refinement")
    workflow.add_edge("query_refinement", "search_info")
    workflow.add_edge("search_info", "summarize")
    workflow.add_edge("summarize", END)
    return workflow.compile()