import os
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from IPython.display import Image, display
from langchain_community.document_loaders import ArxivLoader
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.utilities import SerpAPIWrapper
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import AIMessage
from langchain_core.messages import AnyMessage
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from supabase.client import Client, create_client

load_dotenv('../config.env')

llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
embedding_model = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")

# os.environ.get returns None when a variable is unset; that None is passed
# straight to create_client.
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embedding_model,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Load the system prompt from the file (path is relative to the current
# working directory). An explicit encoding avoids platform-dependent decoding.
with open('system_prompt.txt', 'r', encoding='utf-8') as f:
    system_prompt = f.read()
# print(system_prompt)


# --Agent tools--
# The docstrings below are used as tool descriptions when the functions are
# bound to the LLM via bind_tools, so keep them accurate.

# Calculation tools
def add(a: int, b: int) -> int:
    """
    Add two numbers

    Args:
        a: first int
        b: second int
    """
    return a + b


def subtract(a: int, b: int) -> int:
    """
    Subtract two numbers

    Args:
        a: first int
        b: second int
    """
    return a - b


def multiply(a: int, b: int) -> int:
    """
    Multiply two numbers

    Args:
        a: first int
        b: second int
    """
    return a * b


def modulus(a: int, b: int) -> int:
    """
    Get the modulus (remainder) of two numbers

    Args:
        a: first int
        b: second int
    """
    return a % b


def divide(a: int, b: int) -> float:
    """
    Divide two numbers

    Args:
        a: first int
        b: second int

    Returns:
        The division result as a float

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero")
    return a / b


# Search tools
def web_search(query: str) -> str:
    """
    Searches the web using a query string. Useful for answering current events or fact-based questions.

    Args:
        query: string representing the search term.

    Returns:
        A string containing top search results.
    """
    search = SerpAPIWrapper()
    return search.run(query)


def wiki_search(query: str) -> dict[str, str]:
    """
    Search Wikipedia for general knowledge.

    Args:
        query: Wikipedia search term.

    Returns:
        A dict with "wiki_results" containing the page contents of up to two
        loaded documents, separated by "---" lines.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'\n{doc.page_content}\n' for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}


def arxiv_search(query: str) -> str:
    """
    Searches academic papers on arXiv based on a query.

    Args:
        query: The search term to query arXiv.

    Returns:
        A string of the top retrieved papers.
    """
    docs = ArxivLoader(query=query, max_results=2).load()
    return "\n\n---\n\n".join(
        f"Title: {doc.metadata.get('title', 'N/A')}\nContent: {doc.page_content}"
        for doc in docs
    )


# NOTE(review): arxiv_search is defined above but not registered here, so the
# LLM cannot call it — confirm whether that is intentional.
tools = [
    add,
    subtract,
    multiply,
    divide,
    modulus,
    web_search,
    wiki_search,
]

llm_with_tools = llm.bind_tools(tools=tools)


def build_graph():
    """Build and compile the agent's LangGraph state graph.

    The graph currently wired up has a single ``retriever`` node, which answers
    from the most similar document in the Supabase vector store. The
    tool-calling ``assistant`` loop is kept below in commented-out form.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.
    """

    class AgentState(TypedDict):
        # Conversation history; node outputs are combined into it by the
        # add_messages reducer.
        messages: Annotated[list[AnyMessage], add_messages]

    def assistant(state: AgentState):
        """Call the tool-bound LLM with the system prompt prepended to the history.

        Not registered as a node in the graph built below (its wiring is
        commented out).
        """
        sys_msg = SystemMessage(content=system_prompt)
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: AgentState):
        """Answer the latest message with the closest stored document.

        If that document contains the marker "Final answer :", only the text
        after its last occurrence is returned; otherwise the whole document
        content (stripped) is returned.
        """
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)
        if not results:
            # If no documents are found, provide a fallback response.
            answer = "I couldn't find anything relevant in the knowledge base. Please try rephrasing your question."
        else:
            content = results[0].page_content
            # rpartition gives the text after the last marker, or the whole
            # string when the marker is absent.
            answer = content.rpartition("Final answer :")[2].strip()
        return {"messages": [AIMessage(content=answer)]}

    # Graph
    builder = StateGraph(AgentState)

    # Tool-calling agent wiring (currently disabled in favour of the retriever):
    # Define nodes: these do the work
    # builder.add_node("assistant", assistant)
    # builder.add_node("tools", ToolNode(tools))
    # # Define edges: these determine how the control flow moves
    # builder.add_edge(START, "assistant")
    # builder.add_conditional_edges(
    #     "assistant",
    #     # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
    #     # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
    #     tools_condition,
    # )
    # builder.add_edge("tools", "assistant")

    builder.add_node("retriever", retriever)

    # Define edges: these determine how the control flow moves
    builder.add_edge(START, "retriever")
    builder.set_finish_point("retriever")

    react_graph = builder.compile()

    # Show
    # display(Image(react_graph.get_graph(xray=True).draw_mermaid_png()))

    return react_graph


# test
if __name__ == "__main__":
    # Both runs below go through the graph returned by build_graph(), which is
    # currently the retriever-only graph (no calculation or search tools).
    react_graph = build_graph()

    # Calc test
    print("----Calculation tools test----")
    question = "Calculate the result of 1+2*3+5 and multiply by 2"
    messages = [HumanMessage(content=question)]
    messages = react_graph.invoke({"messages": messages})
    for m in messages['messages']:
        m.pretty_print()

    # Web search test
    print("----Web search tools test----")
    real_question = 'In April of 1977, who was the Prime Minister of the first place mentioned by name in the Book of Esther (in the New International Version)?'
    messages = [HumanMessage(content=real_question)]
    messages = react_graph.invoke({"messages": messages})
    for m in messages['messages']:
        m.pretty_print()