sqfoo commited on
Commit
0df25ad
·
verified ·
1 Parent(s): 907e3f5

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +19 -16
agent.py CHANGED
@@ -278,22 +278,22 @@ class BasicAgent:
278
  If Task ID is included in the question, remember to call the relevant read tools [ie. read_file, excel_read, csv_read, mp3_listen, image_caption]
279
  """
280
  self.tools = [duckduck_websearch, serper_websearch, visit_webpage, wiki_search, text_splitter, youtube_transcript, read_file, excel_read, csv_read, mp3_listen, image_caption, python_tool]
281
- self.binded_model = self.model.bind_tools(self.tools)
282
  self.sys_msg = SystemMessage(content=self.sys_prompt)
283
 
284
  self.prompt = ChatPromptTemplate.from_messages([
285
  ("system", self.sys_prompt),
286
  ("human", "{input}")
287
  ])
288
- self.agent = initialize_agent(
289
- tools=self.tools,
290
- llm=self.model,
291
- agent="zero-shot-react-description", # ReAct agent type
292
- verbose=True,
293
- system_prompt=self.prompt,
294
- handle_parsing_errors="Check your output and make sure it conforms, use the Action/Action Input syntax"
295
- )
296
- self.app = self.__graph_compile__()
297
  print("BasicAgent initialized.")
298
 
299
  def __call__(self, task: dict) -> str:
@@ -310,15 +310,19 @@ class BasicAgent:
310
  # fixed_answer = "This is a default answer."
311
  # fixed_answer = self.agent.run(question)
312
  human_message = [HumanMessage(content=question)]
313
- messages = self.app.invoke({"message": human_message})
314
  fixed_answer = messages['messages'][-1].content
315
  print(f"Agent returning fixed answer: {fixed_answer}")
316
  time.sleep(60)
317
  return fixed_answer
318
 
319
  def __graph_compile__(self):
 
 
 
 
320
  builder = StateGraph(MessagesState)
321
- builder.add_node("assistant", self.assistant)
322
  builder.add_node("tools", ToolNode(self.tools))
323
  builder.add_edge(START, "assistant")
324
  builder.add_conditional_edges(
@@ -328,8 +332,7 @@ class BasicAgent:
328
  tools_condition,
329
  )
330
  builder.add_edge("tools", "assistant")
331
- return builder.compile() # Compile graph
 
332
 
333
- def assistant(self, state: MessagesState):
334
- """Assistant Node"""
335
- return {"message": [self.binded_model.invoke([self.sys_msg] + state["message"])]}
 
278
  If Task ID is included in the question, remember to call the relevant read tools [ie. read_file, excel_read, csv_read, mp3_listen, image_caption]
279
  """
280
  self.tools = [duckduck_websearch, serper_websearch, visit_webpage, wiki_search, text_splitter, youtube_transcript, read_file, excel_read, csv_read, mp3_listen, image_caption, python_tool]
281
+ self.model_with_tools = self.model.bind_tools(self.tools)
282
  self.sys_msg = SystemMessage(content=self.sys_prompt)
283
 
284
  self.prompt = ChatPromptTemplate.from_messages([
285
  ("system", self.sys_prompt),
286
  ("human", "{input}")
287
  ])
288
+ # self.agent = initialize_agent(
289
+ # tools=self.tools,
290
+ # llm=self.model,
291
+ # agent="zero-shot-react-description", # ReAct agent type
292
+ # verbose=True,
293
+ # system_prompt=self.prompt,
294
+ # handle_parsing_errors="Check your output and make sure it conforms, use the Action/Action Input syntax"
295
+ # )
296
+ self.graph = self.__graph_compile__()
297
  print("BasicAgent initialized.")
298
 
299
  def __call__(self, task: dict) -> str:
 
310
  # fixed_answer = "This is a default answer."
311
  # fixed_answer = self.agent.run(question)
312
  human_message = [HumanMessage(content=question)]
313
+ messages = self.graph.invoke({"messages": human_message})
314
  fixed_answer = messages['messages'][-1].content
315
  print(f"Agent returning fixed answer: {fixed_answer}")
316
  time.sleep(60)
317
  return fixed_answer
318
 
319
  def __graph_compile__(self):
320
+ def assistant(self, state: MessagesState):
321
+ """Assistant Node"""
322
+ return {"message": [self.model_with_tools.invoke(state["messages"])]}
323
+
324
  builder = StateGraph(MessagesState)
325
+ builder.add_node("assistant", assistant)
326
  builder.add_node("tools", ToolNode(self.tools))
327
  builder.add_edge(START, "assistant")
328
  builder.add_conditional_edges(
 
332
  tools_condition,
333
  )
334
  builder.add_edge("tools", "assistant")
335
+ # Compile graph
336
+ return builder.compile()
337
 
338
+