from typing import Annotated, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition

from tools import *
|
|
|
|
class AgentState(TypedDict):
    """State carried between the nodes of the agent graph.

    ``messages`` holds the conversation. The key is annotated with the
    LangGraph ``add_messages`` reducer. As a result, the ``messages`` list a
    node returns is merged into the existing history rather than replacing it.
    """

    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
class Agent:
    """Tool-calling chat agent built as a two-node LangGraph loop.

    The graph starts at the ``assistant`` node. ``tools_condition`` decides
    whether to route next to the ``tools`` node or to finish. The ``tools``
    node always routes back to ``assistant``, so the model sees the tool
    output.
    """

    def __init__(self, llm, tools):
        """Bind *tools* to a chat wrapper around *llm* and compile the graph.

        Args:
            llm: Model passed to ``ChatHuggingFace(llm=...)``. Presumably a
                ``HuggingFaceEndpoint`` given this file's imports (confirm).
            tools: Tools bound to the chat model and executed by ``ToolNode``.
        """
        # _initialize_graph builds the ToolNode from self.tools, so the tools
        # must be kept on the instance.
        self.tools = tools
        chat = ChatHuggingFace(llm=llm, verbose=True)
        # Stored on the instance because the `assistant` node needs it after
        # __init__ has returned.
        self.chat_with_tools = chat.bind_tools(tools)
        self._initialize_graph()

    def _initialize_graph(self):
        """Build the assistant/tools loop and store the compiled graph in ``self.agent``."""
        builder = StateGraph(AgentState)

        # Nodes: the tool-bound chat model, and the executor for the tools.
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(self.tools))

        # Edges: START -> assistant; assistant -> tools or end (chosen by
        # tools_condition); tools -> assistant.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")

        self.agent = builder.compile()

    def call_agent(self, messages):
        """Run the compiled graph starting from *messages*.

        Args:
            messages: Initial conversation placed under the ``"messages"`` key.

        Returns:
            Whatever ``self.agent.invoke`` returns. For a compiled LangGraph
            graph this is the final state.
        """
        return self.agent.invoke({"messages": messages})

    def assistant(self, state: "AgentState"):
        """Graph node that sends the conversation to the tool-bound chat model.

        Args:
            state: Current graph state. Only ``state["messages"]`` is read.

        Returns:
            A partial state update holding the model's single reply under
            ``"messages"``.
        """
        reply = self.chat_with_tools.invoke(state["messages"])
        return {"messages": [reply]}
|
|
|
|
|
|
| |