Shaukat39's picture
Update agent.py
ce23fec verified
"""LangGraph Agent"""
import os
import re
import traceback
import uuid

from dotenv import load_dotenv
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from supabase.client import create_client
# Populate os.environ from a local .env file (if one exists) before the os.environ
# lookups further down, e.g. SUPABASE_URL and SUPABASE_SERVICE_KEY.
load_dotenv()
# ------------------ Arithmetic Tools ------------------
@tool
def multiply(a: int, b: int) -> str:
    """
    Multiply two integers and return the result as a string.
    Args:
        a (int): The first integer.
        b (int): The second integer.
    Returns:
        str: The product of a and b, as a string.
    """
    # @tool exposes the docstring above to the LLM as this tool's description.
    product = a * b
    return f"{product}"
@tool
def add(a: int, b: int) -> str:
    """
    Add two integers and return the result as a string.
    Args:
        a (int): The first integer.
        b (int): The second integer.
    Returns:
        str: The sum of a and b, as a string.
    """
    # Tool results are returned as text so they can be placed directly in a ToolMessage.
    total = a + b
    return "{}".format(total)
@tool
def subtract(a: int, b: int) -> str:
    """
    Subtract one integer from another and return the result as a string.
    Args:
        a (int): The minuend.
        b (int): The subtrahend.
    Returns:
        str: The difference (a - b), as a string.
    """
    # Order matters: the result is a minus b, not b minus a.
    difference = a - b
    return str(difference)
@tool
def divide(a: int, b: int) -> str:
    """
    Divide one integer by another and return the result as a string.
    Args:
        a (int): The numerator.
        b (int): The denominator. Must not be zero.
    Returns:
        str: The result of the division (a / b), as a string. Returns an error message if b is zero.
    """
    # True division: the result is a float even for evenly divisible inputs.
    if b != 0:
        return str(a / b)
    # A zero denominator is reported back to the LLM as text instead of raising.
    return "Error: Cannot divide by zero."
@tool
def modulus(a: int, b: int) -> str:
    """
    Compute the modulus (remainder) of two integers and return the result as a string.
    Args:
        a (int): The numerator.
        b (int): The denominator. Must not be zero.
    Returns:
        str: The remainder when a is divided by b, as a string. Returns an error message if b is zero.
    """
    # Mirror divide(): report a zero denominator to the LLM as a readable tool result
    # instead of letting ZeroDivisionError escape from the tool.
    if b == 0:
        return "Error: Cannot compute modulus by zero."
    return str(a % b)
# ------------------ Retrieval Tools ------------------
@tool
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for a given query and return text from up to two matching articles.
    Args:
        query (str): A string query to search on Wikipedia.
    Returns:
        str: Combined content from up to two relevant articles, separated by dividers.
    """
    # Each loaded Document's page_content is the article text; metadata is not returned.
    articles = WikipediaLoader(query=query, load_max_docs=2).load()
    divider = "\n\n---\n\n"
    return divider.join([article.page_content for article in articles])
@tool
def web_search(query: str) -> str:
    """
    Perform a web search using Tavily and return content from the top three results.
    Args:
        query (str): A string representing the web search topic.
    Returns:
        str: Combined content from up to three top results, separated by dividers.
    """
    results = TavilySearchResults(max_results=3).invoke(query)
    # Invoked with a plain string, the Tavily tool returns a list of result dicts
    # (keys such as "url" and "content"), not Documents, so `.page_content` does not
    # exist. On an API failure it may return an error string instead; pass that
    # through so the LLM sees it.
    if isinstance(results, str):
        return results
    return "\n\n---\n\n".join(
        str(item.get("content", "")) if isinstance(item, dict) else str(item)
        for item in results
    )
@tool
def arvix_search(query: str) -> str:
    """
    Search arXiv for academic papers related to the query and return excerpts.
    Args:
        query (str): The search query string.
    Returns:
        str: Excerpts (up to 1000 characters each) from up to three relevant arXiv papers, separated by dividers.
    """
    # The "arvix" spelling is kept on purpose: @tool uses the function name as the
    # tool name the LLM calls.
    excerpt_chars = 1000
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    excerpts = [paper.page_content[:excerpt_chars] for paper in papers]
    return "\n\n---\n\n".join(excerpts)
# ------------------ System Prompt ------------------
# Resolve the prompt file next to this module so importing it from another working
# directory still works; fall back to the CWD-relative path used previously.
_PROMPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "system_prompt.txt")
if not os.path.isfile(_PROMPT_PATH):
    _PROMPT_PATH = "system_prompt.txt"
with open(_PROMPT_PATH, "r", encoding="utf-8") as f:
    # Leading/trailing whitespace is stripped once, at load time.
    system_prompt = f.read().strip()
# ------------------ Supabase Setup ------------------
url = os.environ["SUPABASE_URL"].strip()
key = os.environ["SUPABASE_SERVICE_KEY"].strip()
client = create_client(url, key)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Few-shot QA examples seeded into the vector store; the retriever node in
# build_graph() surfaces similar ones as extra context.
qa_examples = [
    {"content": "Q: What is the capital of Vietnam?\nA: FINAL ANSWER: Hanoi"},
    {"content": "Q: Alphabetize: lettuce, broccoli, basil\nA: FINAL ANSWER: basil,broccoli,lettuce"},
    {"content": "Q: What is 42 multiplied by 8?\nA: FINAL ANSWER: three hundred thirty six"},
]
vector_store = SupabaseVectorStore(
    client=client,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# This code runs on every import of the module. Without explicit ids, add_texts
# assigns fresh random UUIDs and each import adds duplicate rows. Content-derived
# UUIDs turn the write into an idempotent upsert of the same rows, assuming `id` is
# the table's primary key. The namespace is arbitrary but must never change.
qa_texts = [doc["content"] for doc in qa_examples]
qa_ids = [str(uuid.uuid5(uuid.NAMESPACE_URL, text)) for text in qa_texts]
vector_store.add_texts(qa_texts, ids=qa_ids)
print("✅ QA documents embedded into Supabase.")
# NOTE(review): retriever_tool is not included in `tools` below, so the LLM never sees
# it. Its name contains a space, which some function-calling APIs reject; rename it
# before binding it.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="Retrieve similar questions from vector DB.",
)
# Tools bound to the LLM and executed by the graph's tool node.
tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
# ------------------ Build Agent Graph ------------------
class VerboseToolNode(ToolNode):
    """ToolNode that prints the content of every incoming message before running tools."""

    def invoke(self, state, config=None, **kwargs):
        """Log the pending messages, then delegate to ToolNode.invoke.

        LangGraph invokes node runnables as ``node.invoke(input, config)``. The override
        must therefore accept `config` (and any extra kwargs) and forward them;
        a one-argument signature raises TypeError when called that way.
        """
        # ToolNode accepts dict state, a bare message list, or an object with `.messages`.
        if isinstance(state, dict):
            messages = state.get("messages", [])
        elif isinstance(state, list):
            messages = state
        else:
            messages = getattr(state, "messages", [])
        print("🔧 ToolNode evaluating:", [m.content for m in messages])
        return super().invoke(state, config, **kwargs)
# Matches the answer line the model is expected to emit (case-insensitive).
_FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*(.+)", re.IGNORECASE)
# Some reasoning models wrap their chain of thought in <think>...</think>. That text is
# not the answer and should not leak into the fallback output.
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)


def _select_llm(provider: str):
    """Instantiate the (not yet tool-bound) chat model for *provider*.

    Args:
        provider: "google", "groq" or "huggingface".

    Raises:
        ValueError: if *provider* is not one of the names above.
    """
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
    if provider == "groq":
        return ChatGroq(model="qwen-qwq-32b", temperature=0.3)
    if provider == "huggingface":
        # HuggingFaceEndpoint has no `url` field; the endpoint is passed as `endpoint_url`.
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0.3,
            )
        )
    raise ValueError(
        f"Invalid provider: {provider!r}. Choose 'google', 'groq' or 'huggingface'."
    )


def _message_text(result) -> str:
    """Return the stripped text content of an LLM result.

    Accepts objects with a ``content`` attribute or dicts with a ``"content"`` key.
    List-valued content (str parts and/or dicts carrying ``"text"``) is concatenated.

    Raises:
        ValueError: if *result* carries no content at all.
    """
    if hasattr(result, "content"):
        content = result.content
    elif isinstance(result, dict) and "content" in result:
        content = result["content"]
    else:
        raise ValueError(f"Unexpected result format: {repr(result)}")
    if isinstance(content, list):
        content = "".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, (str, dict))
        )
    return str(content).strip()


def _format_final_answer(raw_output: str) -> str:
    """Normalise model text into a single ``FINAL ANSWER: ...`` string.

    The last non-empty ``FINAL ANSWER:`` match wins, presumably because the system
    prompt asks the model to finish with it. Text outside <think> blocks is searched
    first, then the raw output. Without any marker, the visible text itself becomes
    the answer.
    """
    visible = _THINK_RE.sub("", raw_output).strip()
    for text in (visible, raw_output):
        answers = [m.strip() for m in _FINAL_ANSWER_RE.findall(text) if m.strip()]
        if answers:
            return f"FINAL ANSWER: {answers[-1]}"
    print("⚠️ 'FINAL ANSWER:' not found. Raw content will be used as fallback.")
    if not visible:
        return "FINAL ANSWER: Unable to determine answer"
    return f"FINAL ANSWER: {visible}"


def build_graph(provider: str = "groq", similarity_threshold: float = 0.7):
    """Compile the retriever -> assistant <-> tools LangGraph agent.

    Args:
        provider: chat model backend: "google", "groq" or "huggingface".
        similarity_threshold: minimum score a stored QA example needs to be added as
            context (0.7 was previously hard-coded).

    Returns:
        The compiled graph; invoke it with ``{"messages": [HumanMessage(...)]}``.

    Raises:
        ValueError: if *provider* is not supported.
    """
    llm_with_tools = _select_llm(provider).bind_tools(tools)

    def retriever(state: MessagesState):
        """Add stored QA examples similar to the first message as extra HumanMessages."""
        query = state["messages"][0].content
        try:
            similar = vector_store.similarity_search_with_score(query)
        except Exception as e:
            # Examples are optional context; a vector-store failure should not abort the run.
            print(f"⚠️ Retrieval failed, continuing without examples: {e}")
            return {"messages": []}
        # Assumes the score is a similarity (higher = closer), as the >= test implies.
        examples = [
            HumanMessage(content=f"Similar QA:\n{doc.page_content}")
            for doc, score in similar
            if score >= similarity_threshold
        ]
        # MessagesState merges updates with add_messages, so only new messages are returned.
        return {"messages": examples}

    def assistant(state: MessagesState):
        """Call the tool-bound LLM; pass tool calls through and normalise final answers."""
        try:
            messages = [SystemMessage(content=system_prompt)] + state["messages"]
            result = llm_with_tools.invoke(messages)
            # A tool-call turn must reach tools_condition unchanged. Rebuilding it as a
            # fresh AIMessage would drop tool_calls, and the tools would never run.
            tool_calls = getattr(result, "tool_calls", None)
            if tool_calls:
                print("🛠️ LLM requested tools:", tool_calls)
                return {"messages": [result]}
            raw_output = _message_text(result)
            print("🤖 Raw LLM output:", repr(raw_output))
            return {"messages": [AIMessage(content=_format_final_answer(raw_output))]}
        except Exception as e:
            # Deliberate best-effort boundary: always end with an answer-shaped assistant message.
            print(f"🔥 Exception: {e}")
            traceback.print_exc()
            return {"messages": [AIMessage(content=f"FINAL ANSWER: AGENT ERROR: {type(e).__name__}: {e}")]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", VerboseToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the last AIMessage has tool_calls; otherwise finish.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
# ------------------ Local Test Harness ------------------
if __name__ == "__main__":
    # Manual smoke test: build the Groq-backed agent, ask one question, print the answer.
    agent_graph = build_graph(provider="groq")
    sample_question = (
        "When was a picture of St. Thomas Aquinas first added to the Wikipedia page "
        "on the Principle of double effect?"
    )
    final_state = agent_graph.invoke({"messages": [HumanMessage(content=sample_question)]})
    last_message = final_state["messages"][-1]
    print(last_message.content)
# """LangGraph Agent"""
# import os
# from dotenv import load_dotenv
# from langgraph.graph import START, StateGraph, MessagesState
# from langgraph.prebuilt import tools_condition
# from langgraph.prebuilt import ToolNode
# from langchain_google_genai import ChatGoogleGenerativeAI
# from langchain_groq import ChatGroq
# from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
# from langchain_community.tools.tavily_search import TavilySearchResults
# from langchain_community.document_loaders import WikipediaLoader
# from langchain_community.document_loaders import ArxivLoader
# from langchain_community.vectorstores import SupabaseVectorStore
# from langchain_core.messages import SystemMessage, HumanMessage
# from langchain_core.tools import tool
# from langchain.tools.retriever import create_retriever_tool
# from supabase.client import Client, create_client
# load_dotenv()
# @tool
# def multiply(a: int, b: int) -> int:
# """Multiply two numbers.
# Args:
# a: first int
# b: second int
# """
# return a * b
# @tool
# def add(a: int, b: int) -> int:
# """Add two numbers.
# Args:
# a: first int
# b: second int
# """
# return a + b
# @tool
# def subtract(a: int, b: int) -> int:
# """Subtract two numbers.
# Args:
# a: first int
# b: second int
# """
# return a - b
# @tool
# def divide(a: int, b: int) -> int:
# """Divide two numbers.
# Args:
# a: first int
# b: second int
# """
# if b == 0:
# raise ValueError("Cannot divide by zero.")
# return a / b
# @tool
# def modulus(a: int, b: int) -> int:
# """Get the modulus of two numbers.
# Args:
# a: first int
# b: second int
# """
# return a % b
# @tool
# def wiki_search(query: str) -> str:
# """Search Wikipedia for a query and return maximum 2 results.
# Args:
# query: The search query."""
# search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
# formatted_search_docs = "\n\n---\n\n".join(
# [
# f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
# for doc in search_docs
# ])
# return {"wiki_results": formatted_search_docs}
# @tool
# def web_search(query: str) -> str:
# """Search Tavily for a query and return maximum 3 results.
# Args:
# query: The search query."""
# search_docs = TavilySearchResults(max_results=3).invoke(query=query)
# formatted_search_docs = "\n\n---\n\n".join(
# [
# f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
# for doc in search_docs
# ])
# return {"web_results": formatted_search_docs}
# @tool
# def arvix_search(query: str) -> str:
# """Search Arxiv for a query and return maximum 3 result.
# Args:
# query: The search query."""
# search_docs = ArxivLoader(query=query, load_max_docs=3).load()
# formatted_search_docs = "\n\n---\n\n".join(
# [
# f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
# for doc in search_docs
# ])
# return {"arvix_results": formatted_search_docs}
# # load the system prompt from the file
# with open("system_prompt.txt", "r", encoding="utf-8") as f:
# system_prompt = f.read()
# # System message
# sys_msg = SystemMessage(content=system_prompt)
# # build a retriever
# embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
# supabase: Client = create_client(
# os.environ.get("SUPABASE_URL"),
# os.environ.get("SUPABASE_SERVICE_KEY"))
# vector_store = SupabaseVectorStore(
# client=supabase,
# embedding= embeddings,
# table_name="documents",
# query_name="match_documents_langchain",
# )
# create_retriever_tool = create_retriever_tool(
# retriever=vector_store.as_retriever(),
# name="Question Search",
# description="A tool to retrieve similar questions from a vector store.",
# )
# tools = [
# multiply,
# add,
# subtract,
# divide,
# modulus,
# wiki_search,
# web_search,
# arvix_search,
# ]
# # Build graph function
# def build_graph(provider: str = "groq"):
# """Build the graph"""
# # Load environment variables from .env file
# if provider == "google":
# # Google Gemini
# llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
# elif provider == "groq":
# # Groq https://console.groq.com/docs/models
# llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
# elif provider == "huggingface":
# # TODO: Add huggingface endpoint
# llm = ChatHuggingFace(
# llm=HuggingFaceEndpoint(
# url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
# temperature=0,
# ),
# )
# else:
# raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
# # Bind tools to LLM
# llm_with_tools = llm.bind_tools(tools)
# # Node
# def assistant(state: MessagesState):
# """Assistant node"""
# return {"messages": [llm_with_tools.invoke(state["messages"])]}
# def retriever(state: MessagesState):
# """Retriever node"""
# similar_question = vector_store.similarity_search(state["messages"][0].content)
# example_msg = HumanMessage(
# content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
# )
# return {"messages": [sys_msg] + state["messages"] + [example_msg]}
# builder = StateGraph(MessagesState)
# builder.add_node("retriever", retriever)
# builder.add_node("assistant", assistant)
# builder.add_node("tools", ToolNode(tools))
# builder.add_edge(START, "retriever")
# builder.add_edge("retriever", "assistant")
# builder.add_conditional_edges(
# "assistant",
# tools_condition,
# )
# builder.add_edge("tools", "assistant")
# # Compile graph
# return builder.compile()
# # test
# if __name__ == "__main__":
# question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
# # Build the graph
# graph = build_graph(provider="groq")
# # Run the graph
# messages = [HumanMessage(content=question)]
# messages = graph.invoke({"messages": messages})
# for m in messages["messages"]:
# m.pretty_print()