Spaces:
Sleeping
Sleeping
File size: 5,972 Bytes
0adc292 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
from typing import TypedDict, List
from langchain.memory import ConversationBufferMemory
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser,JsonOutputToolsParser
from pydantic import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
import os
load_dotenv()  # pull variables from a local .env file into os.environ
api_key = os.getenv("GEMINI_API_KEY")  # None when the variable is not set
# NOTE(review): `memory` is not referenced anywhere in the code shown here;
# chat history is tracked in AgentState["chat_history"] instead. Confirm it is
# unused elsewhere before removing it.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)
# ---------- Step 1: Define State ----------
class AgentState(TypedDict):
    """Shared state passed between the nodes of the LangGraph workflow."""

    query: str  # user query; refine_query overwrites it with the refined text
    banks: List[str]  # banks extracted by extract_entities
    category: str  # NOTE(review): no code shown here reads or writes this key
    product: str  # financial product extracted by extract_entities
    chat_history: List[str]  # "User:..." / "Assistant:..." transcript lines
    search_results: str  # stringified Tavily search results (search_info)
    risk_analysis: str  # NOTE(review): no code shown here reads or writes this key
    final_answer: str  # answer text produced by summarize
    # Routing label ("Bank" / "Non-Bank") written by classify and read by the
    # conditional edge in build_agent. LangGraph drops node updates for keys
    # that are not declared in the state schema, so without this entry the
    # label is discarded and the router's state['cat'] lookup fails.
    cat: str
# ---------- Step 2: Define schema for parsing ----------
# Target schema for extract_entities; `parser` turns the LLM's JSON reply into
# this model. Intentionally docstring-free: pydantic would copy a class
# docstring into the JSON schema that PydanticOutputParser embeds in prompts.
class QueryEntities(BaseModel):
    banks: List[str] = Field(
        description="List of banks mentioned",
    )
    product: str = Field(
        description="Financial product (like FD, loan, etc.)",
    )
# Structured reply expected from the query-rewriting LLM (parsed by `parser1`
# in refine_query). Kept docstring-free so the JSON schema embedded in the
# prompt's format instructions is not altered.
class RefinedQuery(BaseModel):
    refined_query: str = Field(..., description="The improved and precise query")
    clarification_needed: bool = Field(..., description="True if query is ambiguous and needs clarification")
    # Only meaningful when clarification_needed is True, so it is optional:
    # a reply that omits it must still parse instead of raising.
    clarification_question: str = Field(default="", description="The question to ask user if clarification_needed is True")
# ---------- Step 3: LLMs ----------
# One Gemini chat model shared by the LLM-calling nodes (classify,
# extract_entities, refine_query, summarize).
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.5, api_key=api_key)
# Parser for the entity-extraction reply. Its format instructions are baked
# into entity_prompt as a partial variable, so callers only supply {query}.
parser = PydanticOutputParser(pydantic_object=QueryEntities)
_entity_format_instructions = parser.get_format_instructions()
entity_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Extract the banks and product from the query."),
        ("user", "{query}\n{format_instructions}"),
    ]
).partial(format_instructions=_entity_format_instructions)
# Prompt for classify(). The reply must be exactly one of the two labels that
# the conditional edge in build_agent maps: "Bank" continues the pipeline,
# "Non-Bank" ends the run. The original wording mixed "Banking/Non-Banking"
# with "Bank/Non-Bank", which invited replies outside that path map.
classify_prompt = ChatPromptTemplate.from_messages([
    (
        "system",
        "Classify whether the user's query is about banking. Reply with exactly "
        "one label and nothing else (no quotes, punctuation or explanation): "
        "Bank if the query is banking-related, otherwise Non-Bank.",
    ),
    ("user", "{query}"),
])
# ---------- Step 4: Define nodes ----------
def extract_entities(state: AgentState):
    """Fill ``state["banks"]`` and ``state["product"]`` from the user's query.

    The query is rendered through ``entity_prompt`` and sent to ``llm``. The
    reply text is then parsed into a ``QueryEntities`` object by ``parser``.

    NOTE(review): this node is not registered in the graph that build_agent
    builds; confirm whether it is meant to be wired in.
    """
    prompt_value = entity_prompt.format_prompt(query=state["query"])
    reply = llm.invoke(prompt_value)
    extracted = parser.parse(reply.content)
    state.update(banks=extracted.banks, product=extracted.product)
    return state
def classify(state:AgentState):
query=state['query']
msg=classify_prompt.format_messages(query=query)
cat=llm.invoke(msg)
state['cat']=cat.content.strip()
return state
# Web-search tool used by search_info, capped at this many results per query.
_TAVILY_MAX_RESULTS = 10
tavily = TavilySearchResults(max_results=_TAVILY_MAX_RESULTS)
def update_history(state: AgentState):
    """Record the current turn's messages in ``state["chat_history"]``.

    Appends ``"User:<query>"`` when a query is present, and
    ``"Assistant:<final_answer>"`` when a final answer is present.

    A new list is built rather than appending to the existing one, so a
    ``chat_history`` list supplied by the caller is not mutated as a side
    effect. A missing or ``None`` history is treated as empty.

    Args:
        state: Graph state. ``query``, ``final_answer`` and ``chat_history``
            are read if present.

    Returns:
        The same state dict with ``chat_history`` replaced by the new list.
    """
    history = list(state.get("chat_history") or [])
    if state.get("query"):
        history.append(f"User:{state['query']}")
    if state.get("final_answer"):
        # NOTE(review): summarize also appends the user query and its answer
        # to chat_history, so a final_answer carried over from a previous
        # turn may be recorded twice; confirm this is intended.
        history.append(f"Assistant:{state['final_answer']}")
    state["chat_history"] = history
    return state
def search_info(state: AgentState):
    """Run a Tavily web search for the current query.

    The raw tool output is converted to text and stored in
    ``state["search_results"]``, which summarize includes in its prompt.
    """
    payload = {"query": state["query"]}
    state["search_results"] = str(tavily.invoke(payload))
    return state
# Parser for the query-rewriting reply produced in refine_query.
parser1 = PydanticOutputParser(pydantic_object=RefinedQuery)


def refine_query(state: AgentState):
    """Rewrite ``state["query"]`` into a more precise query using chat history.

    The prompt combines the newline-joined ``chat_history``, the latest query
    and ``parser1``'s format instructions. The LLM reply is parsed into a
    ``RefinedQuery``.

    If the reply's ``refined_query`` is empty or whitespace (e.g. when the
    model only asks for clarification), the original query is kept. This
    stops the next node from searching with an empty string.

    Returns:
        The same state dict, with ``query`` possibly replaced.
    """
    history_text = "\n".join(state.get("chat_history") or [])
    prompt = f"""You are a query rewriting agent.
Use the conversations context and latest query to produce more precise and umabiguous query. Ask for clarifying questions if the
query seems ambiguous.
Rules:
Always consider the history of the chat as context .
Do not lose context from previous turns
COnversations so far
{history_text}
latest_query:
{state['query']}
{parser1.get_format_instructions()}
"""
    refined = llm.invoke(prompt)
    parsed = parser1.parse(refined.content)
    # NOTE(review): parsed.clarification_needed / clarification_question are
    # not acted on; the graph always proceeds to search. Confirm intended.
    state["query"] = parsed.refined_query.strip() or state["query"]
    return state
def summarize(state: AgentState):
    """Generate the final structured answer and record the turn in chat history.

    Builds a prompt from the chat history (newline-joined), the current
    ``query`` and ``search_results``, then stores the LLM's reply in
    ``state["final_answer"]``. Afterwards, ``"User: <query>"`` and
    ``"Assistant: <answer>"`` are appended to ``chat_history``.

    A missing or ``None`` history is treated as empty. The original code
    indexed ``state["chat_history"]`` before its own ``setdefault`` and could
    raise ``KeyError``. The caller's list is copied, not mutated in place.

    Raises:
        KeyError: if ``query`` or ``search_results`` is absent from ``state``.

    Returns:
        The same state dict with ``final_answer`` and ``chat_history`` set.
    """
    history = list(state.get("chat_history") or [])
    history_text = "\n".join(history)  # same rendering as refine_query
    prompt = f"""You are an financial assitant which will help user with their banking queries.
1. use chat_history and query as context .
2.Think step by step about the user query
-Main products discussed in the query
-what are the key features and benefits
-Are there any hidden charges or risk involved
3. Be objective and unbiased
Step 4: Create the final answer in this structure:
- **Summary (2–3 sentences):** High-level explanation of the product.
- **Key Features:** Bullet points of the most important details.
- **Risks/Limitations:** Bullet points of what the user should watch out for.
- **Actionable Insight:** One suggestion or next step the user could take.
chat_history:{history_text}
Query: {state['query']}
Results: {state['search_results']}"""
    answer = llm.invoke(prompt)
    state["final_answer"] = answer.content
    # NOTE(review): update_history (the "add_history" node) has already added
    # the raw user query this turn; this appends the refined query as well.
    state["chat_history"] = history + [
        f"User: {state['query']}",
        f"Assistant: {answer.content}",
    ]
    return state
# ---------- Step 5: Build Graph ----------
def _route_by_category(state: AgentState) -> str:
    """Choose the branch to take after the ``classify_query`` node.

    Case, whitespace, quotes and punctuation in ``state["cat"]`` are ignored.
    So "Bank", "'bank'." or "Banking" continue the pipeline. Every other label
    (including "Non-Bank") maps to "Non-Bank" and ends the run, rather than
    raising on a key missing from the edge's path map.
    """
    letters = "".join(ch for ch in str(state["cat"]).lower() if ch.isalpha())
    return "Bank" if letters in ("bank", "banking") else "Non-Bank"


def build_agent(state: "AgentState | None" = None):
    """Build and compile the banking-assistant LangGraph workflow.

    Graph shape::

        classify_query --"Bank"-->     add_history -> query_refinement
                                       -> search_info -> summarize -> END
        classify_query --"Non-Bank"--> END

    NOTE(review): the "Non-Bank" branch ends without setting final_answer,
    so callers receive no reply for non-banking queries.

    Args:
        state: Unused. Kept (now optional) so existing callers that pass a
            state continue to work.

    Returns:
        The compiled graph from ``StateGraph.compile()``.
    """
    workflow = StateGraph(AgentState)
    workflow.add_node("classify_query", classify)
    workflow.add_node("add_history", update_history)
    workflow.add_node("search_info", search_info)
    workflow.add_node("query_refinement", refine_query)
    workflow.add_node("summarize", summarize)

    workflow.set_entry_point("classify_query")
    workflow.add_conditional_edges(
        "classify_query",
        _route_by_category,
        {
            "Bank": "add_history",
            "Non-Bank": END,
        },
    )
    workflow.add_edge("add_history", "query_refinement")
    workflow.add_edge("query_refinement", "search_info")
    workflow.add_edge("search_info", "summarize")
    workflow.add_edge("summarize", END)
    return workflow.compile()
|