RamiNuraliyev's picture
Upload agent.py
8885680 verified
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities import SerpAPIWrapper
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from typing import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from IPython.display import Image, display
from langchain_core.messages import AIMessage
from langchain_community.vectorstores import SupabaseVectorStore
from supabase.client import Client, create_client
import os
from langchain_google_genai import GoogleGenerativeAIEmbeddings
# Load environment variables from the config file before the clients below
# read them (the Supabase settings are fetched from os.environ further down).
load_dotenv('../config.env')

# Chat model used by the agent and the embedding model used for retrieval.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
embedding_model = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")

# Supabase-backed vector store queried by the retriever node.
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embedding_model,
    table_name="documents",
    query_name="match_documents_langchain",
)

# load the system prompt from the file
# Explicit UTF-8 so the prompt is decoded the same way on every platform
# instead of depending on the locale's default encoding.
with open('system_prompt.txt', 'r', encoding='utf-8') as f:
    system_prompt = f.read()
# print(system_prompt)
# --Agent tools--
# Calculation tools
def add(a: int, b: int) -> int:
    """
    Return the sum of two integers.

    Args:
        a: first addend
        b: second addend
    """
    total = a + b
    return total
def subtract(a: int, b: int) -> int:
    """
    Return the difference of two integers (a minus b).

    Args:
        a: value to subtract from
        b: value to subtract
    """
    difference = a - b
    return difference
def multiply(a: int, b: int) -> int:
    """
    Return the product of two integers.

    Args:
        a: first factor
        b: second factor
    """
    product = a * b
    return product
def modulus(a: int, b: int) -> int:
    """
    Return the remainder of dividing a by b (Python `%` semantics).

    Args:
        a: dividend
        b: divisor
    """
    remainder = a % b
    return remainder
def divide(a: int, b: int) -> float:
    """
    Return the true-division quotient of two integers.

    Args:
        a: dividend
        b: divisor

    Returns:
        The quotient a / b as a float

    Raises:
        ValueError: if b is zero
    """
    if b != 0:
        return a / b
    # A zero divisor is reported as a ValueError with a fixed message.
    raise ValueError("Cannot divide by zero")
# Search tools
def web_search(query: str) -> str:
    """
    Searches the web using a query string. Useful for answering current events or fact-based questions.

    Args:
        query: string representing the search term.

    Returns:
        A string containing top search results.
    """
    # NOTE(review): this function is handed to llm.bind_tools, so the docstring
    # above is presumably what the model sees as the tool description; keep it
    # free of stray characters (a trailing '",' was removed from the summary).
    search = SerpAPIWrapper()
    result = search.run(query)
    return result
def wiki_search(query: str) -> dict[str, str]:
    """
    Search Wikipedia for general knowledge.

    Args:
        query: Wikipedia search term.

    Returns:
        A dict with "wiki_results" containing search results.
    """
    # Load at most two documents for the query.
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Each document is wrapped in a <Document ...> ... </Document> element.
    # The opening tag is no longer self-closing, so it pairs with the closing
    # tag; a missing "source" falls back to "" just like "page" does.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}">\n'
        f'{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
def arxiv_search(query: str) -> str:
    """
    Searches academic papers on arXiv based on a query.

    Args:
        query: The search term to query arXiv.

    Returns:
        A string of the top retrieved papers, one "Title: ... / Content: ..."
        entry per paper, separated by "---" lines.
    """
    # ArxivLoader forwards extra keyword arguments to ArxivAPIWrapper, whose
    # documented document cap is `load_max_docs`; `max_results` is not one of
    # its options, so the intended limit of two papers was not being applied.
    docs = ArxivLoader(query=query, load_max_docs=2).load()

    def _title(doc) -> str:
        # The loader's documented metadata key is capitalized ("Title"); the
        # lowercase key is kept as a fallback before giving up with "N/A".
        meta = doc.metadata
        return meta.get("Title", meta.get("title", "N/A"))

    return "\n\n---\n\n".join(
        f"Title: {_title(doc)}\nContent: {doc.page_content}"
        for doc in docs
    )
# Tools made available to the LLM: arithmetic helpers first, then the search
# tools. arxiv_search is defined in this module but was missing from the list,
# so it could never be called; it is registered here alongside the others.
tools = [
    add,
    subtract,
    multiply,
    divide,
    modulus,
    web_search,
    wiki_search,
    arxiv_search,
]
# Chat model with the tool list bound to it (used by the assistant node).
llm_with_tools = llm.bind_tools(tools=tools)
def build_graph():
    """
    Build and compile the agent's LangGraph workflow.

    The compiled graph currently contains a single node, "retriever", which is
    both the entry point and the finish point. The tool-calling assistant node
    is defined but its wiring is disabled (kept below as comments).

    Returns:
        The compiled graph returned by `StateGraph.compile()`.
    """

    class AgentState(TypedDict):
        # Conversation messages; `add_messages` is attached as the reducer.
        messages: Annotated[list[AnyMessage], add_messages]

    def assistant(state: AgentState):
        # Prepend the system prompt and let the tool-bound LLM answer.
        # Not connected to the graph at the moment.
        prompt = [SystemMessage(content=system_prompt)] + state["messages"]
        reply = llm_with_tools.invoke(prompt)
        return {"messages": [reply]}

    def retriever(state: AgentState):
        # Look up the single closest stored document for the latest message.
        latest_text = state["messages"][-1].content
        matches = vector_store.similarity_search(latest_text, k=1)
        if matches:
            stored_text = matches[0].page_content
            # Keep only what follows the last "Final answer :" marker; when the
            # marker is absent, rpartition leaves the whole text in slot 2.
            reply_text = stored_text.rpartition("Final answer :")[2].strip()
        else:
            # If no documents are found, provide a fallback response.
            reply_text = "I couldn't find anything relevant in the knowledge base. Please try rephrasing your question."
        return {"messages": [AIMessage(content=reply_text)]}

    graph_builder = StateGraph(AgentState)

    # Disabled assistant/tools wiring, kept for reference:
    # graph_builder.add_node("assistant", assistant)
    # graph_builder.add_node("tools", ToolNode(tools))
    # graph_builder.add_edge(START, "assistant")
    # graph_builder.add_conditional_edges(
    #     "assistant",
    #     # If the latest message from assistant is a tool call -> tools_condition routes to tools
    #     # If the latest message from assistant is not a tool call -> tools_condition routes to END
    #     tools_condition,
    # )
    # graph_builder.add_edge("tools", "assistant")

    # Active wiring: START -> retriever -> finish.
    graph_builder.add_node("retriever", retriever)
    graph_builder.add_edge(START, "retriever")
    graph_builder.set_finish_point("retriever")

    # To visualize the graph in a notebook:
    # display(Image(compiled.get_graph(xray=True).draw_mermaid_png()))
    compiled = graph_builder.compile()
    return compiled
# test
if __name__ == "__main__":
    # Manual smoke test: build the graph once, then run each sample question
    # through it and pretty-print every message of the resulting state.
    react_graph = build_graph()
    sample_runs = (
        (
            "----Calculation tools test----",
            "Calculate the result of 1+2*3+5 and multiply by 2",
        ),
        (
            "----Web search tools test----",
            'In April of 1977, who was the Prime Minister of the first place mentioned by name in the Book of Esther (in the New International Version)?',
        ),
    )
    for banner, prompt_text in sample_runs:
        print(banner)
        final_state = react_graph.invoke({"messages": [HumanMessage(content=prompt_text)]})
        for message in final_state['messages']:
            message.pretty_print()