# agent.py — last change "Update agent.py" by giulia-fontanella (commit 7eb753f, verified).
from typing import Annotated, TypedDict

from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition

from tools import extract_text
# Shared state passed between graph nodes: a single "messages" list.
# The list is annotated with LangGraph's ``add_messages`` reducer, so the
# messages a node returns are merged into the existing conversation rather
# than replacing it.
AgentState = TypedDict(
    "AgentState",
    {"messages": Annotated[list[AnyMessage], add_messages]},
)
class Agent:
    """Tool-using chat agent built as a LangGraph ``StateGraph``.

    The compiled graph starts at the ``assistant`` node. ``tools_condition``
    decides after each assistant turn whether to route to the ``tools`` node.
    Tool output is always sent back to ``assistant``.
    """

    def __init__(self, llm):
        """Create the tool-bound chat model and compile the graph.

        Args:
            llm: Model handed to ``ChatHuggingFace(llm=...)``; presumably a
                ``HuggingFaceEndpoint`` — confirm against callers.
        """
        chat = ChatHuggingFace(llm=llm, verbose=True)
        # Store the tool list on the instance. The same list must be bound to
        # the model (so it can request calls) and given to the ToolNode (so
        # those calls can actually be executed).
        self.tools = [extract_text, DuckDuckGoSearchRun()]
        # Kept on the instance so the ``assistant`` node can reach it.
        self.chat_with_tools = chat.bind_tools(self.tools)
        self._initialize_graph()

    def _initialize_graph(self):
        """Wire the assistant/tools loop and store the compiled graph in ``self.agent``."""
        builder = StateGraph(AgentState)
        # Define nodes
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(self.tools))
        # Define edges: begin at the assistant, let tools_condition choose the
        # next step, and loop tool results back into the assistant.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        # Compile the graph
        self.agent = builder.compile()

    def call_agent(self, messages):
        """Run the compiled graph on *messages* and return its result.

        Args:
            messages: Initial conversation, placed under the ``"messages"``
                key of the graph state.

        Returns:
            The value returned by ``self.agent.invoke`` (for a compiled
            StateGraph, the final state).
        """
        return self.agent.invoke({"messages": messages})

    def assistant(self, state: "AgentState"):
        """Graph node: send the conversation to the tool-bound chat model.

        Args:
            state: Current graph state; only ``state["messages"]`` is read.

        Returns:
            A partial state update holding the model's single reply. The
            ``add_messages`` reducer on ``AgentState`` merges it into the
            conversation.
        """
        return {"messages": [self.chat_with_tools.invoke(state["messages"])]}