giulia-fontanella commited on
Commit
c2f85cf
·
verified ·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +45 -0
agent.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph.message import add_messages
2
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
3
+ from langgraph.prebuilt import ToolNode
4
+ from langgraph.graph import START, StateGraph
5
+ from langgraph.prebuilt import tools_condition
6
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
7
+ from tools import *
8
+
9
+
10
+ class AgentState(TypedDict):
11
+ messages: Annotated[list[AnyMessage], add_messages]
12
+
13
+
14
+ class Agent():
15
+ def __init__(self, llm, tools):
16
+ chat = ChatHuggingFace(llm=llm, verbose=True)
17
+ chat_with_tools = chat.bind_tools(tools)
18
+ self._initialize_graph()
19
+
20
+ def _initialize_graph(self):
21
+ builder = StateGraph(AgentState)
22
+
23
+ # Define nodes
24
+ builder.add_node("assistant", self.assistant)
25
+ builder.add_node("tools", ToolNode(self.tools))
26
+
27
+ # Define edges
28
+ builder.add_edge(START, "assistant")
29
+ builder.add_conditional_edges("assistant",tools_condition)
30
+ builder.add_edge("tools", "assistant")
31
+
32
+ # Compile the graph
33
+ self.agent = builder.compile()
34
+
35
+ def call_agent(self, messages):
36
+ self.agent.invoke({"messages":messages})
37
+
38
+ def assistant(state: AgentState):
39
+ return {
40
+ "messages": [chat_with_tools.invoke(state["messages"])],
41
+ }
42
+
43
+
44
+
45
+