mabelwang21 commited on
Commit
4ea88b5
·
1 Parent(s): 88b5dd6

update _assistant_node

Browse files
Files changed (1) hide show
  1. agent.py +42 -21
agent.py CHANGED
@@ -266,20 +266,27 @@ class AgentState(TypedDict):
266
  input_file: Optional[List[str]] # Contains file path (PDF/PNG)
267
  messages: Annotated[List[BaseMessage], add_messages]
268
 
 
269
  # === Agent Class ===
270
  class MyAgent:
271
  def __init__(
272
  self,
273
- model_name: str = "anthropic:claude-3-5-sonnet-latest",
274
  temperature: float = 0.0
275
  ):
276
- # Initialize LLM
277
- self.llm = init_chat_model(model_name, temperature=temperature)
278
- # Base tools
279
- self.tools = tools
280
- # RAG components
281
- self.docs: List[Any] = []
282
- self.retriever: Optional[BM25Retriever] = None
 
 
 
 
 
 
283
 
284
  def add_files(self, file_paths: List[str]):
285
  """
@@ -391,23 +398,25 @@ class MyAgent:
391
  if file_paths:
392
  state["input_file"] = file_paths
393
 
394
- # Build graph
395
  builder = StateGraph(dict)
396
  builder.add_node("assistant", self._assistant_node)
397
  builder.add_node("tools", ToolNode(self.tools))
398
  builder.add_edge(START, "assistant")
399
 
400
- # Always allow the assistant to hand off to the tools node
401
  builder.add_conditional_edges(
402
  "assistant",
403
- lambda s: any(t.name in s["messages"][-1].content for t in self.tools),
404
- {True: "tools", False: "assistant"}
405
  )
406
  builder.add_edge("tools", "assistant")
 
 
407
  graph = builder.compile()
408
-
409
- # Use invoke() instead of run()
410
- out = graph.invoke(state)
411
  last_message = out["messages"][-1].content
412
 
413
  # Extract only the FINAL ANSWER part
@@ -423,19 +432,31 @@ class MyAgent:
423
  def _assistant_node(self, state: dict) -> dict:
424
  """Process messages with the LLM."""
425
  try:
426
- # Check if messages exist
427
- if not state["messages"]:
428
  # Add a system message if empty
429
- state["messages"].append(SystemMessage(content=SYSTEM_PROMPT))
430
-
 
 
 
 
 
 
 
 
 
 
431
  # Invoke the chat model with our BaseMessage list
432
- resp = self.llm(state["messages"])
433
  state["messages"].append(resp)
434
  return state
435
  except Exception as e:
436
- # Handle errors by adding an error message
437
  error_msg = f"Error calling LLM: {str(e)}"
438
  print(error_msg)
 
 
 
439
  return state
440
 
441
 
 
266
  input_file: Optional[List[str]] # Contains file path (PDF/PNG)
267
  messages: Annotated[List[BaseMessage], add_messages]
268
 
269
+
270
  # === Agent Class ===
271
  class MyAgent:
272
  def __init__(
273
  self,
274
+ model_name: str = "anthropic:claude-3-5-sonnet-latest", # <-- Use a valid model name
275
  temperature: float = 0.0
276
  ):
277
+ try:
278
+ self.llm = init_chat_model(
279
+ model_name,
280
+ temperature=temperature
281
+ )
282
+ # Base tools
283
+ self.tools = tools
284
+ # RAG components
285
+ self.docs: List[Any] = []
286
+ self.retriever: Optional[BM25Retriever] = None
287
+ except Exception as e:
288
+ print(f"Error initializing LLM: {e}")
289
+ raise
290
 
291
  def add_files(self, file_paths: List[str]):
292
  """
 
398
  if file_paths:
399
  state["input_file"] = file_paths
400
 
401
+ # Build graph with proper conditional edge to prevent loops
402
  builder = StateGraph(dict)
403
  builder.add_node("assistant", self._assistant_node)
404
  builder.add_node("tools", ToolNode(self.tools))
405
  builder.add_edge(START, "assistant")
406
 
407
+ # Fix conditional edges with better check
408
  builder.add_conditional_edges(
409
  "assistant",
410
+ tools_condition, # Use built-in tools_condition
411
+ "tools"
412
  )
413
  builder.add_edge("tools", "assistant")
414
+
415
+ # Add recursion_limit to prevent infinite loops
416
  graph = builder.compile()
417
+
418
+ # Use invoke() with higher recursion limit
419
+ out = graph.invoke(state, {"recursion_limit": 10}) # Lower limit
420
  last_message = out["messages"][-1].content
421
 
422
  # Extract only the FINAL ANSWER part
 
432
  def _assistant_node(self, state: dict) -> dict:
433
  """Process messages with the LLM."""
434
  try:
435
+ # Check if messages exist and ensure proper format
436
+ if not state.get("messages") or len(state["messages"]) == 0:
437
  # Add a system message if empty
438
+ state["messages"] = [SystemMessage(content=SYSTEM_PROMPT)]
439
+
440
+ # Ensure we have at least a system and user message
441
+ has_system = any(isinstance(m, SystemMessage) for m in state["messages"])
442
+ has_human = any(isinstance(m, HumanMessage) for m in state["messages"])
443
+
444
+ if not has_system:
445
+ state["messages"].insert(0, SystemMessage(content=SYSTEM_PROMPT))
446
+
447
+ if not has_human:
448
+ state["messages"].append(HumanMessage(content="Hello"))
449
+
450
  # Invoke the chat model with our BaseMessage list
451
+ resp = self.llm.invoke(state["messages"])
452
  state["messages"].append(resp)
453
  return state
454
  except Exception as e:
 
455
  error_msg = f"Error calling LLM: {str(e)}"
456
  print(error_msg)
457
+ print(f"Message count: {len(state.get('messages', []))}")
458
+ if state.get("messages"):
459
+ print(f"Message types: {[type(m).__name__ for m in state['messages']]}")
460
  return state
461
 
462