# Alfred_agent / agent.py
# Author: Shan Gao
# change 2976609
import os
import datasets
from huggingface_hub import list_models
""" LangChain / LangGraph imports """
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langchain.tools import Tool
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
# Build retriever
# Load the dataset and make Documents
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
docs = [
Document(
page_content="\n".join(
[
f"Name: {guest['name']}",
f"Relation: {guest['relation']}",
f"Description: {guest['description']}",
f"Email: {guest['email']}",
]
),
metadata={"name": guest["name"]},
)
for guest in guest_dataset
]
# Embeddings & Vectorstore retriever
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
encode_kwargs={"normalize_embeddings": True},
)
vectorstore = FAISS.from_documents(docs, embeddings)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
# Guest info tool
def extract_text(query: str) -> str:
"""Retrieves detailed information about gala guests based on their name or relation."""
results = retriever.invoke(query)
if results:
return "\n\n".join([doc.page_content for doc in results])
else:
return "No matching guest information found."
guest_info_tool = Tool(
name="guest_info_retriever",
func=extract_text,
description="Retrieves detailed information about gala guests based on their name or relation.",
)
# huggingface hub statistics tool
def get_hub_stats(author: str) -> str:
"""Fetches the most downloaded model from a specific author on the Hugging Face Hub."""
try:
# List models from the specified author, sorted by downloads
models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
if models:
model = models[0]
return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
else:
return f"No models found for author {author}."
except Exception as e:
return f"Error fetching models for {author}: {str(e)}"
# Initialize the tool
hub_stats_tool = Tool(
name="get_hub_stats",
func=get_hub_stats,
description="Fetches the most downloaded model from a specific author on the Hugging Face Hub."
)
# Web search tool
search_tool = DuckDuckGoSearchRun()
tools = [guest_info_tool, hub_stats_tool, search_tool]
class AgentState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
# Build graph function
def build_graph(hf_token: str):
llm = HuggingFaceEndpoint(
repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
huggingfacehub_api_token=hf_token)
chat = ChatHuggingFace(llm=llm, verbose=True)
chat_with_tools = chat.bind_tools(tools)
def assistant(state: AgentState):
# Produce one assistant message (may include a tool call)
return {"messages": [chat_with_tools.invoke(state["messages"])]}
builder = StateGraph(AgentState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
# Compile graph
return builder.compile()
# test
if __name__ == "__main__":
# get API key
api_key = os.getenv('HF_TOKEN')
question = "Who is the president of France?"
graph = build_graph(hf_token=api_key)
messages = [HumanMessage(content=question)]
messages = graph.invoke({"messages": messages})
for m in messages["messages"]:
m.pretty_print()