giulia-fontanella commited on
Commit
a380329
·
verified ·
1 Parent(s): 9b2bab8

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +9 -5
agent.py CHANGED
@@ -12,7 +12,7 @@ class AgentState(TypedDict):
12
  messages: Annotated[list[AnyMessage], add_messages]
13
 
14
 
15
- class Agent():
16
  def __init__(self, llm):
17
  chat = ChatHuggingFace(llm=llm, verbose=True)
18
 
@@ -21,7 +21,9 @@ class Agent():
21
 
22
  chat_with_tools = chat.bind_tools([extract_text,search_tool])
23
  self._initialize_graph()
 
24
 
 
25
  def _initialize_graph(self):
26
  builder = StateGraph(AgentState)
27
 
@@ -37,14 +39,16 @@ class Agent():
37
  # Compile the graph
38
  self.agent = builder.compile()
39
 
40
- def call_agent(self, messages):
41
- self.agent.invoke({"messages":messages})
 
 
 
 
42
 
43
  def assistant(state: AgentState):
44
  return {
45
  "messages": [chat_with_tools.invoke(state["messages"])],
46
  }
47
 
48
-
49
-
50
 
 
12
  messages: Annotated[list[AnyMessage], add_messages]
13
 
14
 
15
+ class BasicAgent():
16
  def __init__(self, llm):
17
  chat = ChatHuggingFace(llm=llm, verbose=True)
18
 
 
21
 
22
  chat_with_tools = chat.bind_tools([extract_text,search_tool])
23
  self._initialize_graph()
24
+ print("BasicAgent initialized.")
25
 
26
+
27
  def _initialize_graph(self):
28
  builder = StateGraph(AgentState)
29
 
 
39
  # Compile the graph
40
  self.agent = builder.compile()
41
 
42
+ def __call__(self, question: str) -> str:
43
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
44
+ response = self.agent.invoke({"messages":question})
45
+ answer = response['messages'][-1].content)
46
+ print(f"Agent returning answer: {answer}")
47
+ return answer
48
 
49
  def assistant(state: AgentState):
50
  return {
51
  "messages": [chat_with_tools.invoke(state["messages"])],
52
  }
53
 
 
 
54