Spaces:
Build error
Build error
| """LangGraph Agent with Hugging Face LLM and Robust Retriever""" | |
| import os | |
| from dotenv import load_dotenv | |
| from langgraph.graph import START, StateGraph, MessagesState | |
| from langgraph.prebuilt import tools_condition, ToolNode | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langchain_community.document_loaders import ArxivLoader | |
| from langchain_community.vectorstores import SupabaseVectorStore | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from langchain_core.tools import tool | |
| from supabase.client import Client, create_client | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| # Define mathematical tools for basic operations | |
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: Left-hand factor.
        b: Right-hand factor.

    Returns:
        The value of ``a * b``.
    """
    product = a * b
    return product
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: First addend.
        b: Second addend.

    Returns:
        The value of ``a + b``.
    """
    total = a + b
    return total
def subtract(a: int, b: int) -> int:
    """Return the difference of two integers.

    Args:
        a: Value to subtract from.
        b: Value to subtract.

    Returns:
        The value of ``a - b``.
    """
    difference = a - b
    return difference
def divide(a: int, b: int) -> int:
    """Return the integer (floor) quotient of two integers.

    Args:
        a: Dividend.
        b: Divisor.

    Returns:
        The value of ``a // b``; floor division keeps the result an integer.

    Raises:
        ValueError: If ``b`` is zero.
    """
    if b != 0:
        # Floor division on purpose: the tool always answers with an integer.
        return a // b
    raise ValueError("Cannot divide by zero.")
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: Dividend.
        b: Divisor.

    Returns:
        Remainder of ``a`` divided by ``b`` (``a % b``).

    Raises:
        ValueError: If ``b`` is zero.
    """
    # Mirror divide(): report a zero divisor as ValueError with a clear message
    # instead of letting a bare ZeroDivisionError escape from the tool.
    if b == 0:
        raise ValueError("Cannot take modulus by zero.")
    return a % b
| # Define search tools for external information retrieval | |
def wiki_search(query: str) -> dict:
    """Look up a query on Wikipedia and format at most two hits.

    Args:
        query: The search query.

    Returns:
        A dict whose ``"wiki_results"`` entry holds the formatted documents,
        separated by ``---`` divider lines.
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    rendered = []
    for doc in loader.load():
        # One pseudo-XML record per document: source/page header, then the text.
        rendered.append(
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        )
    return {"wiki_results": "\n\n---\n\n".join(rendered)}
def web_search(query: str) -> dict:
    """Search Tavily for a query and return up to 3 results.

    Args:
        query: The search query.

    Returns:
        Dictionary with formatted web search results under ``"web_results"``.
    """
    # invoke() takes the tool input as its first positional argument; the
    # previous invoke(query=query) call raised TypeError before any search ran.
    search_docs = TavilySearchResults(max_results=3).invoke({"query": query})

    # NOTE(review): the tool appears to return a plain string (an error
    # description) instead of a list of hits when the request fails; pass it
    # through rather than indexing it like a list of dicts. Confirm against
    # the installed langchain_community version.
    if isinstance(search_docs, str):
        return {"web_results": search_docs}

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc["url"]}" title="{doc.get("title", "")}">\n{doc["content"]}\n</Document>'
        for doc in search_docs
    )
    return {"web_results": formatted_search_docs}
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return up to 3 results.

    Each result's text is truncated to its first 1000 characters.

    Args:
        query: The search query.

    Returns:
        Dictionary with formatted Arxiv results under ``"arxiv_results"``.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()

    formatted = []
    for doc in search_docs:
        meta = doc.metadata
        # NOTE(review): ArxivLoader metadata is not expected to carry a
        # "source" key (it exposes fields such as "Title"/"Published"), so a
        # hard ["source"] lookup raises KeyError. Fall back to whichever
        # identifier is present; confirm the keys against the installed version.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        formatted.append(
            f'<Document source="{source}" page="{meta.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        )

    return {"arxiv_results": "\n\n---\n\n".join(formatted)}
# Read the system prompt text from disk (text mode is open()'s default)
with open("system_prompt.txt", encoding="utf-8") as f:
    system_prompt = f.read()

# Wrap the prompt so it can be placed in front of the conversation history
sys_msg = SystemMessage(content=system_prompt)

# Embedding model used by the vector store
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# Supabase connection; URL and service key come from the environment
supabase: Client = create_client(
    os.getenv("SUPABASE_URL"),
    os.getenv("SUPABASE_SERVICE_KEY"),
)

# Vector store over the "documents" table, queried via "match_documents_langchain"
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# Tools handed to the LLM and to the graph's tool node
tools = [
    # Arithmetic helpers
    multiply, add, subtract, divide, modulus,
    # External search helpers
    wiki_search, web_search, arxiv_search,
]
def build_graph(provider: str = "huggingface"):
    """Build the LangGraph workflow for the agent.

    The graph starts at ``retriever``, continues to ``assistant``, and loops
    through ``tools`` back to ``assistant`` whenever ``tools_condition``
    routes there.

    Args:
        provider: The LLM provider to use ('huggingface' by default).

    Returns:
        Compiled LangGraph workflow.

    Raises:
        ValueError: If ``provider`` is anything other than 'huggingface'.
    """
    # Load environment variables (also done once at import time).
    load_dotenv()

    # Initialize LLM based on provider
    if provider != "huggingface":
        raise ValueError("Only 'huggingface' provider is supported.")

    llm = ChatHuggingFace(
        llm=HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
            temperature=0.1,  # Low temperature for near-deterministic responses
            max_new_tokens=512,  # Limit response length
            timeout=60,  # Timeout for API calls
        )
    )

    # Bind tools to the LLM so it can emit tool calls
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Generate the next response with the tool-enabled LLM.

        Args:
            state: Current state with messages.

        Returns:
            State update containing the LLM response.
        """
        # The system prompt is prepended on every call; it is not stored in state.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Search the vector store for a document similar to the last message.

        Args:
            state: Current state with messages.

        Returns:
            State update with one AIMessage: the retrieved answer, or a
            fallback notice when the store returns nothing.
        """
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)

        if results:
            content = results[0].page_content
            # split() returns the whole text when the marker is absent, so
            # taking the last piece covers both the marked and unmarked case.
            answer = content.split("Final answer :")[-1].strip()
        else:
            answer = "No relevant information found in the vector store. Relying on LLM and tools."

        # Return only the new message: MessagesState merges node output into
        # the existing history (add_messages reducer), so re-sending
        # state["messages"] was redundant and did not prepend anything.
        return {"messages": [AIMessage(content=answer)]}

    # Initialize graph
    builder = StateGraph(MessagesState)

    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Define edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,  # Route to tools if needed
    )
    builder.add_edge("tools", "assistant")

    # Compile and return graph
    return builder.compile()