# agent.py — LangGraph agent wiring a Groq chat model to math, Wikipedia,
# Tavily, arXiv and vector-store retrieval tools.
import os
import certifi
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import Chroma
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
load_dotenv()
# ---------------- CONFIGURATION ----------------
# Change this to any valid Hugging Face model endpoint (e.g., meta-llama/Llama-3-8b-chat-hf)
HF_MODEL_NAME = os.getenv("LLAMA_MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
HF_MODEL_URL = f"https://api-inference.huggingface.co/models/{HF_MODEL_NAME}"
# Use the OpenAI-compatible inference endpoint
HF_OPENAI_URL = "https://api-inference.huggingface.co/openai"
# ---------------- UTILITY TOOLS ----------------
@tool
def multiply_numbers(x: int, y: int) -> int:
    """Multiply two integers and return the result."""
    # Multiplication is commutative, so operand order is irrelevant.
    product = y * x
    return product
@tool
def add_numbers(x: int, y: int) -> int:
    """Add two integers and return the sum."""
    total = x + y
    return total
@tool
def subtract_numbers(x: int, y: int) -> int:
    """Subtract the second integer from the first and return the result."""
    difference = x - y
    return difference
@tool
def divide_numbers(x: int, y: int) -> float:
    """Divide the first number by the second and return the result. Raises an error on division by zero."""
    # Guard clause: reject a zero divisor before attempting the division.
    if y != 0:
        return x / y
    raise ValueError("Division by zero is not allowed.")
@tool
def modulus_numbers(x: int, y: int) -> int:
    """Return the remainder when the first number is divided by the second."""
    # divmod's second element is exactly x % y for integers.
    _, remainder = divmod(x, y)
    return remainder
@tool
def power_numbers(base: float, exponent: float) -> float:
    """Raise the base to the power of exponent and return the result."""
    # Built-in pow(a, b) is equivalent to a ** b for two arguments.
    result = pow(base, exponent)
    return result
@tool
def root_number(value: float, n: float) -> float:
    """Compute the nth root of a value and return the result.

    Raises:
        ValueError: if n is zero (the exponent 1/n would otherwise raise a
            bare ZeroDivisionError), consistent with divide_numbers' guard.
    """
    if n == 0:
        raise ValueError("Cannot take the 0th root of a number.")
    # NOTE(review): a negative value with a fractional 1/n exponent yields a
    # complex result / ValueError from ** — presumably callers pass
    # non-negative values; confirm before hardening further.
    return value ** (1 / n)
@tool
def wiki_lookup(query: str) -> str:
    """Search Wikipedia for the query and return up to 2 summarized documents."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # FIX: the opening tag previously ended with "/>" (self-closing) while
    # still wrapping content, producing malformed markup — use ">".
    # Use .get for "source" too: a missing metadata key should not crash the tool.
    return "\n\n---\n\n".join(
        f'<Document source="{d.metadata.get("source", "")}" page="{d.metadata.get("page", "")}">{d.page_content}</Document>'
        for d in docs
    )
@tool
def web_lookup(query: str) -> str:
    """Search the web using Tavily and return up to 3 summarized results."""
    # FIX: BaseTool.invoke takes the tool input positionally —
    # .invoke(query=query) omitted the required input argument. Tavily also
    # returns a list of plain dicts (with "url"/"content" keys), not Document
    # objects, so the old d.metadata / d.page_content access would raise
    # AttributeError on every call.
    results = TavilySearchResults(max_results=3).invoke(query)
    return "\n\n---\n\n".join(
        f'<Document source="{r.get("url", "")}">{r.get("content", "")}</Document>'
        for r in results
    )
@tool
def arxiv_lookup(query: str) -> str:
    """Search arXiv for the query and return summaries of up to 3 papers."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    # FIX: opening tag previously ended with "/>" while wrapping content
    # (malformed markup), and ArxivLoader metadata is not guaranteed to carry
    # a "source" key — use .get so a missing key cannot crash the tool.
    # Content is truncated to 800 chars to keep tool output compact.
    return "\n\n---\n\n".join(
        f'<Document source="{d.metadata.get("source", "")}" page="{d.metadata.get("page", "")}">{d.page_content[:800]}</Document>'
        for d in docs
    )
# FIX(review): this span re-declared add_numbers, subtract_numbers,
# divide_numbers, modulus_numbers, power_numbers, root_number, wiki_lookup,
# web_lookup and arxiv_lookup with bodies identical to the definitions above.
# In Python a later module-level def silently shadows the earlier one, so the
# duplicates added nothing and would mask any fix applied to the canonical
# definitions. Removed; the canonical definitions live above.
# (Removed: a fully commented-out earlier implementation of the vectorstore
# setup, system-prompt loading and graph construction that used a HuggingFace
# inference endpoint. It was superseded by the live Groq-based implementation
# below; recover it from version control if needed.)
# ---------------- EMBEDDINGS & VECTOR DB ----------------
# NOTE: everything in this section runs at import time — building the Chroma
# store and reading system_prompt.txt are module-level side effects.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
# Single placeholder document — presumably real reference examples are meant
# to be indexed here; TODO confirm with the author.
sample_docs = [Document(page_content="Sample doc.", metadata={"source":"wiki"})]
split_docs = text_splitter.split_documents(sample_docs)
# In-memory Chroma store over the split documents; also queried directly by
# the graph's retrieval node via similarity_search.
vector_db = Chroma.from_documents(documents=split_docs, embedding=embedding_model)
retriever_tool = create_retriever_tool(
    retriever=vector_db.as_retriever(),
    name="SimilarQuestionFinder",
    description="Retrieve similar questions and examples from vector DB."
)
# Every tool the agent may call, including the vector-store retriever.
all_tools = [multiply_numbers, add_numbers, subtract_numbers, divide_numbers,
             modulus_numbers, power_numbers, root_number,
             wiki_lookup, web_lookup, arxiv_lookup, retriever_tool]
# ---------------- SYSTEM PROMPT ----------------
# Raises FileNotFoundError at import if system_prompt.txt is absent.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_content = f.read()
system_message = SystemMessage(content=system_content)
# ---------------- BUILD GRAPH ----------------
def construct_agent_graph():
    """Build and compile the LangGraph agent.

    Flow: prepend the system prompt and a similar stored example, then loop
    between the Groq LLM and tool execution until the model stops emitting
    tool calls.

    Returns:
        A compiled LangGraph runnable accepting {"messages": [...]}.
    """
    # FIX: bind the tool schemas to the model. Without .bind_tools() the LLM
    # never emits tool calls, so tools_condition always routes to __end__ and
    # the ToolNode below is dead code (the commented-out predecessor of this
    # function did bind its tools).
    llama_llm = ChatGroq(
        model="qwen-qwq-32b",
        api_key=os.environ["GROQ_API_KEY"],  # KeyError here if unset
        temperature=0,
    ).bind_tools(all_tools)

    def retrieve_node(state: MessagesState):
        # Prepend the system prompt; append the closest stored example (if
        # any) as extra context for the model.
        msgs = [system_message] + state["messages"]
        similar = vector_db.similarity_search(state["messages"][0].content)
        if similar:
            msgs.append(HumanMessage(content=f"Reference example:\n{similar[0].page_content}"))
        return {"messages": msgs}

    def respond_node(state: MessagesState):
        # Single LLM step over the accumulated message history.
        return {"messages": [llama_llm.invoke(state["messages"])]}

    graph = StateGraph(MessagesState)
    graph.add_node("find_similar", retrieve_node)
    graph.add_node("generate_answer", respond_node)
    graph.add_node("tool_executor", ToolNode(tools=all_tools))
    graph.add_edge(START, "find_similar")
    graph.add_edge("find_similar", "generate_answer")
    # tools_condition returns "tools" when the last AI message carries tool
    # calls, otherwise "__end__".
    graph.add_conditional_edges(
        "generate_answer",
        tools_condition,
        {"tools": "tool_executor", "__end__": "__end__"},
    )
    graph.add_edge("tool_executor", "generate_answer")
    return graph.compile()
# ---------------- RUN EXAMPLE ----------------
if __name__ == "__main__":
    # Smoke-test the agent on a single question and dump the transcript.
    agent = construct_agent_graph()
    question = "When was St. Thomas Aquinas added to that page?"
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()