| import cmath
|
| import os
|
| from typing import Dict, List, Sequence, TypedDict, cast
|
|
|
| from dotenv import load_dotenv
|
| from langchain.tools.retriever import create_retriever_tool
|
| from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
|
| from langchain_community.vectorstores import SupabaseVectorStore
|
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| from langchain_core.tools import tool
|
| from langchain_google_genai import ChatGoogleGenerativeAI
|
| from langchain_groq import ChatGroq
|
| from langchain_huggingface import (
|
| ChatHuggingFace,
|
| HuggingFaceEmbeddings,
|
| HuggingFaceEndpoint,
|
| )
|
| from langchain_tavily import TavilySearch
|
| from langgraph.graph import END, START, MessagesState, StateGraph
|
| from langgraph.prebuilt import ToolNode, tools_condition
|
| from pydantic import BaseModel
|
| from supabase.client import Client, create_client
|
|
|
|
|
# Populate the process environment (via python-dotenv) before any of the
# os.environ / os.getenv lookups further down in this module run.
load_dotenv()
|
|
|
|
|
class WebSearchInput(BaseModel):
    """Pydantic model holding the single argument of a web search request."""

    # Free-text search query.
    query: str
|
|
|
|
|
class WikipediaSearchInput(BaseModel):
    """Pydantic model holding the single argument of a Wikipedia search request."""

    # Free-text search query.
    query: str
|
|
|
|
|
class ArxivSearchInput(BaseModel):
    """Pydantic model holding the single argument of an arXiv search request."""

    # Free-text search query.
    query: str
|
|
|
|
|
@tool
def search_web(query: str) -> dict:
    """Search the web using Tavily and return at most 3 relevant results.

    Args:
        query: The search query.

    Returns:
        A dict with a single ``"web_results"`` key whose value is the hits
        rendered as ``<Document ...>`` snippets joined by ``---`` separators
        (an empty string when there are no hits).
    """
    response = TavilySearch(max_results=3).invoke({"query": query})

    # Per the langchain-tavily documentation the tool returns a payload dict
    # whose "results" entry is the list of hits; a bare list is still accepted
    # so list-returning search tools keep working.
    if isinstance(response, dict):
        hits = response.get("results") or []
    elif isinstance(response, (list, tuple)):
        hits = response
    else:
        hits = []

    rendered = []
    for hit in hits:
        if isinstance(hit, dict):
            # Tavily hits are plain dicts keyed by "url" / "content".
            source = hit.get("url", hit.get("source", ""))
            page = hit.get("page", "")
            text = hit.get("content", "")
        else:
            # Document-like objects (metadata + page_content).
            metadata = getattr(hit, "metadata", None) or {}
            source = metadata.get("source", "")
            page = metadata.get("page", "")
            text = getattr(hit, "page_content", str(hit))
        rendered.append(
            f'<Document source="{source}" page="{page}"/>\n{text}\n</Document>'
        )

    return {"web_results": "\n\n---\n\n".join(rendered)}
|
|
|
|
|
@tool
def search_wikipedia(query: str) -> dict:
    """Search English Wikipedia and return the content of up to 2 matching articles.

    Args:
        query: The search query.

    Returns:
        ``{"wiki_results": ...}`` with the articles joined by ``---``
        separators, or ``{"error": ...}`` when nothing matched or the lookup
        failed.
    """
    # Deliberately best-effort: any failure is reported back to the caller as
    # an error payload instead of being raised.
    try:
        docs = WikipediaLoader(query=query, lang="en", load_max_docs=2).load()
        if not docs:
            return {"error": f"No Wikipedia articles found for query: {query}"}

        sections = []
        for doc in docs:
            # Label each article with its own title so that several hits can
            # be told apart; fall back to the query when no title is present.
            title = doc.metadata.get("title", query)
            sections.append(f"Wikipedia Article: {title}\n\n{doc.page_content}")

        return {"wiki_results": "\n\n---\n\n".join(sections)}
    except Exception as e:
        return {"error": f"Error searching Wikipedia: {str(e)}"}
|
|
|
|
|
@tool
def arxiv_search(query: str) -> dict:
    """Search arXiv for a query and return at most 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with a single ``"arxiv_results"`` key whose value is the hits
        rendered as ``<Document ...>`` snippets (content truncated to the first
        1000 characters) joined by ``---`` separators.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()

    rendered = []
    for doc in search_docs:
        metadata = doc.metadata or {}
        # A hard ["source"] lookup raises KeyError when the loader does not
        # supply that key, so fall back to other identifying fields.
        source = (
            metadata.get("source")
            or metadata.get("entry_id")
            or metadata.get("Title", "")
        )
        page = metadata.get("page", "")
        rendered.append(
            f'<Document source="{source}" page="{page}"/>\n'
            f"{doc.page_content[:1000]}\n</Document>"
        )

    return {"arxiv_results": "\n\n---\n\n".join(rendered)}
|
|
|
|
|
@tool
def power(a: float, b: float) -> float:
    """
    Raise the first number to the power of the second.
    Args:
        a (float): the base
        b (float): the exponent
    """
    raised = a ** b
    return raised
|
|
|
|
|
@tool
def square_root(a: float) -> float | complex:
    """
    Compute the square root of a number.
    Args:
        a (float): the number whose square root is wanted
    """
    # Non-negative inputs take the real root; everything else is delegated to
    # cmath so that a complex result is produced instead of an error.
    return a ** 0.5 if a >= 0 else cmath.sqrt(a)
|
|
|
|
|
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two numbers.
    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
|
|
|
|
|
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two numbers.
    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
|
|
|
|
|
@tool
def subtract(a: int, b: int) -> int:
    """Return the first number minus the second.
    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
|
|
|
|
|
@tool
def divide(a: float, b: float) -> float:
    """
    Divide the first number by the second.
    Args:
        a (float): the dividend
        b (float): the divisor
    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        # Explicit guard so the caller gets a readable message rather than a
        # bare ZeroDivisionError.
        raise ValueError("Cannot divide by zero.")
    return a / b
|
|
|
|
|
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of dividing the first number by the second.
    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
|
|
|
|
|
|
|
# Text of the system message: instructs the model to reason with its tools and
# to finish with a strictly formatted "FINAL ANSWER: ..." line.
_SYSTEM_PROMPT_TEXT = (
    "You are a helpful assistant tasked with answering questions using a set of tools.\n"
    "Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:\n"
    "FINAL ANSWER: [YOUR FINAL ANSWER].\n"
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.\n"
    'Your answer should only start with "FINAL ANSWER: ", then follows with the answer. '
)

system_prompt = SystemMessage(content=_SYSTEM_PROMPT_TEXT)
|
|
|
# --- Supabase-backed vector store -------------------------------------------
# Connection settings are read from the environment; both values are None when
# the corresponding variable is unset.
supabase_url = os.environ.get("SUPABASE_URL")
supabase_service_key = os.environ.get("SUPABASE_SERVICE_KEY")

# Sentence-transformer model used to embed queries for similarity search.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

supabase: Client = create_client(supabase_url, supabase_service_key)

vector_store = SupabaseVectorStore(
    embedding=embeddings,
    client=supabase,
    query_name="match_documents_langchain",
    table_name="documents",
)

# NOTE(review): this rebinds the imported ``create_retriever_tool`` factory to
# the tool instance it returns, so the factory is no longer reachable under
# that name after this statement. The resulting tool is not part of ``tools``.
create_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    description="A tool to retrieve similar questions from a vector store.",
    name="Question Search",
)
|
|
|
|
|
# Tools handed to the chat model via bind_tools(); the order is kept as-is.
tools = [
    # Information lookup
    search_wikipedia,
    search_web,
    arxiv_search,
    # Arithmetic helpers
    power,
    square_root,
    multiply,
    divide,
    subtract,
    add,
    modulus,
]
|
|
|
|
|
def build_agent_graph(provider: str = "groq"):
    """Build and compile the agent's LangGraph state graph.

    A chat model is created for ``provider`` and bound to the module-level
    ``tools``. The compiled graph currently consists of a single ``retriever``
    node that is both entry and finish point; the ``assistant`` closure is
    defined but is not registered as a node.

    Args:
        provider: Which chat model backend to initialise: ``"groq"``,
            ``"gemini"`` or ``"huggingface"``.

    Returns:
        The compiled graph produced by ``StateGraph(MessagesState).compile()``.

    Raises:
        Exception: If ``provider`` is unknown or the chat model cannot be
            constructed. The underlying error is attached as ``__cause__``.
    """
    try:
        if provider == "groq":
            chat_model = ChatGroq(model="qwen-qwq-32b", temperature=0)
        elif provider == "gemini":
            chat_model = ChatGoogleGenerativeAI(
                model="gemini-2.5-pro",
                temperature=1.0,
                max_retries=2,
                google_api_key=os.getenv("GEMINI_API_KEY"),
            )
        elif provider == "huggingface":
            # NOTE(review): confirm that HuggingFaceEndpoint accepts ``url``;
            # its documented parameters appear to be ``endpoint_url``/``repo_id``.
            llm = HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
            chat_model = ChatHuggingFace(llm=llm, verbose=True)
        else:
            raise ValueError("Invalid provider.")
    except Exception as e:
        # Same exception type and message as before; ``from e`` makes the
        # original failure the explicit cause in the traceback.
        raise Exception(f"Failed to initialize LLM: {str(e)}") from e

    llm_with_tools = chat_model.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: run the tool-enabled model on the conversation so far."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: answer from the closest document in the vector store."""
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)

        if not results:
            print(f"[retriever] No similar documents found for query: {query}")
            return {
                "messages": [
                    AIMessage(content="I couldn't find any similar content in memory.")
                ]
            }

        content = results[0].page_content
        # Keep only the text after the last "Final answer :" marker; when the
        # marker is absent, split() yields the whole content unchanged.
        answer = content.split("Final answer :")[-1].strip()
        return {"messages": [AIMessage(content=answer)]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    # Only the retriever participates in the graph: it is entry and exit.
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")

    return builder.compile()
|
|
|
|
|
|
|
def test_agent():
    """Run a manual smoke test of the agent and print the outcome.

    Checks the environment, builds the graph with the ``groq`` provider, sends
    one fixed question through it and prints the last message of the result.
    All failures are reported on stdout; nothing is raised and nothing is
    returned.
    """
    banner = "=" * 50
    print("\n" + banner)
    print("Starting Agent Test")
    print(banner)

    # Hard requirements: abort on the first one that is missing.
    for required in ("HUGGINGFACEHUB_API_TOKEN", "GEMINI_API_KEY"):
        if not os.getenv(required):
            print(f"\nError: {required} not set")
            return

    provider = "groq"
    # Non-blocking hint: the provider used below presumably reads its own
    # credential from the environment (GROQ_API_KEY for langchain_groq).
    if not os.getenv("GROQ_API_KEY"):
        print(
            f"\nWarning: GROQ_API_KEY not set - the {provider} provider may fail to initialize"
        )

    if not os.getenv("TAVILY_API_KEY"):
        print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")

    # SUPABASE_URL configures the vector store, not web search.
    if not os.getenv("SUPABASE_URL"):
        print(
            "\nWarning: SUPABASE_URL not set - vector store retrieval will be unavailable"
        )

    print("\nInitializing agent...")
    try:
        graph = build_agent_graph(provider=provider)
        print("Agent initialized successfully")
    except Exception as e:
        print(f"Failed to initialize agent: {str(e)}")
        return

    question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
    print("\nTesting question:", question)
    print("-" * 50)

    try:
        messages = [HumanMessage(content=question)]

        print("\nWaiting for response...")
        result = graph.invoke({"messages": messages})

        if result and "messages" in result and result["messages"]:
            answer = result["messages"][-1].content
            print("\nResponse received:")
            print("-" * 20)
            print(answer)
            print("-" * 20)
        else:
            print("\nError: No response from agent")
    except Exception as e:
        # Top-level boundary of a manual test: report and carry on to the footer.
        print(f"\nError processing question: {str(e)}")

    print("\n" + banner)
    print("Test Complete")
    print(banner + "\n")
|
|
|
|
|
|
|
# Run the manual smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_agent()