"""LangGraph Agent with Hugging Face LLM and Robust Retriever""" import os from dotenv import load_dotenv from langgraph.graph import START, StateGraph, MessagesState from langgraph.prebuilt import tools_condition, ToolNode from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings from langchain_community.tools.tavily_search import TavilySearchResults from langchain_community.document_loaders import WikipediaLoader from langchain_community.document_loaders import ArxivLoader from langchain_community.vectorstores import SupabaseVectorStore from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from langchain_core.tools import tool from supabase.client import Client, create_client # Load environment variables from .env file load_dotenv() # Define mathematical tools for basic operations @tool def multiply(a: int, b: int) -> int: """Multiply two numbers. Args: a: First integer b: Second integer Returns: Product of a and b """ return a * b @tool def add(a: int, b: int) -> int: """Add two numbers. Args: a: First integer b: Second integer Returns: Sum of a and b """ return a + b @tool def subtract(a: int, b: int) -> int: """Subtract two numbers. Args: a: First integer b: Second integer Returns: Difference of a and b """ return a - b @tool def divide(a: int, b: int) -> int: """Divide two numbers. Args: a: First integer b: Second integer Returns: Quotient of a divided by b Raises: ValueError: If b is zero """ if b == 0: raise ValueError("Cannot divide by zero.") return a // b # Integer division for consistency @tool def modulus(a: int, b: int) -> int: """Get the modulus of two numbers. Args: a: First integer b: Second integer Returns: Remainder of a divided by b """ return a % b # Define search tools for external information retrieval @tool def wiki_search(query: str) -> dict: """Search Wikipedia for a query and return up to 2 results. Args: query: The search query Returns: Dictionary with formatted Wikipedia results """ search_docs = WikipediaLoader(query=query, load_max_docs=2).load() formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc.page_content}\n' for doc in search_docs ] ) return {"wiki_results": formatted_search_docs} @tool def web_search(query: str) -> dict: """Search Tavily for a query and return up to 3 results. Args: query: The search query Returns: Dictionary with formatted web search results """ search_docs = TavilySearchResults(max_results=3).invoke(query=query) formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc["content"]}\n' for doc in search_docs ] ) return {"web_results": formatted_search_docs} @tool def arxiv_search(query: str) -> dict: """Search Arxiv for a query and return up to 3 results. 

# Load system prompt from file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# Create system message for the LLM
sys_msg = SystemMessage(content=system_prompt)

# Initialize embeddings for the vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Initialize Supabase client and vector store
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY")
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain"
)

# Define tools list
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search
]


def build_graph(provider: str = "huggingface"):
    """Build the LangGraph workflow for the agent.

    Args:
        provider: The LLM provider to use ('huggingface' by default)

    Returns:
        Compiled LangGraph workflow
    """
    # Initialize LLM based on provider (environment is already loaded at import)
    if provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
                temperature=0.1,     # Low temperature for near-deterministic responses
                max_new_tokens=512,  # Limit response length
                timeout=60           # Timeout (seconds) for API calls
            )
        )
    else:
        raise ValueError(f"Unsupported provider {provider!r}; only 'huggingface' is supported.")

    # Bind tools to the LLM so it can emit tool calls
    llm_with_tools = llm.bind_tools(tools)

    # Define assistant node to process queries with the LLM
    def assistant(state: MessagesState):
        """Assistant node: generate a response (or tool calls) using the LLM.

        Args:
            state: Current state with messages

        Returns:
            State update with the LLM response
        """
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    # Define retriever node to fetch similar documents
    def retriever(state: MessagesState):
        """Retriever node: search the vector store for similar questions.

        Args:
            state: Current state with messages

        Returns:
            State update with the retrieved answer or a fallback message
        """
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)
        # MessagesState appends returned messages via the add_messages reducer,
        # so only the new message needs to be returned here.
        if not results:
            return {"messages": [AIMessage(content=(
                "No relevant information found in the vector store. "
                "Relying on LLM and tools."
            ))]}
        content = results[0].page_content
        if "Final answer :" in content:
            answer = content.split("Final answer :")[-1].strip()
        else:
            answer = content.strip()
        return {"messages": [AIMessage(content=answer)]}

    # Initialize graph
    builder = StateGraph(MessagesState)

    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Define edges: retrieve first, then let the assistant loop through tools
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,  # Route to "tools" if the LLM made tool calls, else END
    )
    builder.add_edge("tools", "assistant")

    # Compile and return graph
    return builder.compile()
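

# Minimal usage sketch. This assumes the environment variables referenced above
# (SUPABASE_URL, SUPABASE_SERVICE_KEY, HUGGINGFACEHUB_API_TOKEN, and a Tavily
# key for web_search) are set and that the Supabase "documents" table is
# populated; the question below is purely illustrative.
if __name__ == "__main__":
    graph = build_graph(provider="huggingface")
    question = "What is 25 multiplied by 17?"
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()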