Lasdw committed on
Commit
4a325e1
·
1 Parent(s): 28e126d

Implemented prompt pruning

Browse files
Files changed (1) hide show
  1. agent.py +106 -55
agent.py CHANGED
@@ -757,6 +757,7 @@ Action:
757
  }
758
  ```
759
  Observation: [another tool result will appear here]
 
760
  IMPORTANT: You MUST strictly follow the ReAct pattern (Reasoning, Action, Observation):
761
  1. First reason about the problem in the "Thought" section
762
  2. Then decide what action to take in the "Action" section (using the tools)
@@ -764,11 +765,11 @@ IMPORTANT: You MUST strictly follow the ReAct pattern (Reasoning, Action, Observ
764
  4. Based on the observation, continue with another thought
765
  5. This cycle repeats until you have enough information to provide a final answer
766
 
767
- NEVER fake or simulate tool output yourself.
768
 
769
  ... (this Thought/Action/Observation cycle can repeat as needed) ...
770
  Thought: I now know the final answer
771
- Final Answer: Directly answer the question in the shortest possible way. For example, if the question is "What is the capital of France?", the answer should be "Paris" without any additional text. If the question is "What is the population of New York City?", the answer should be "8.4 million" without any additional text.
772
  Make sure to follow any formatting instructions given by the user.
773
  Now begin! Reminder to ALWAYS use the exact characters `Final Answer:` when you provide a definitive answer."""
774
 
@@ -836,32 +837,106 @@ class AgentState(TypedDict, total=False):
836
  messages: Annotated[list[AnyMessage], add_messages]
837
  current_tool: Optional[str]
838
  action_input: Optional[ActionInput]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839
 
840
  def assistant(state: AgentState) -> Dict[str, Any]:
841
  """Assistant node that processes messages and decides on next action."""
842
  print("Assistant Called...\n\n")
843
 
844
- # Always include the system message at the beginning of the messages list
845
- # This ensures the LLM follows the correct ReAct pattern in every call
 
 
 
 
846
  system_msg = SystemMessage(content=SYSTEM_PROMPT)
847
 
848
- # Get user messages from state, but leave out any existing system messages
849
- user_messages = [msg for msg in state["messages"] if not isinstance(msg, SystemMessage)]
 
 
 
 
 
 
 
 
 
 
 
850
 
851
- # Combine system message with user messages
852
- messages = [system_msg] + user_messages
853
 
 
 
 
 
 
854
  # Get response from the assistant
855
- response = chat_with_tools.invoke(messages, stop=["Observation:"])
856
  print(f"Assistant response type: {type(response)}")
857
- print(f"Response content: {response.content}...")
 
 
858
 
859
  # Extract the action JSON from the response text
860
  action_json = extract_json_from_text(response.content)
861
  print(f"Extracted action JSON: {action_json}")
862
 
863
- # Create a copy of the assistant's response to add to the message history
864
- assistant_message = AIMessage(content=response.content)
 
 
 
 
865
 
866
  if action_json and "action" in action_json and "action_input" in action_json:
867
  tool_name = action_json["action"]
@@ -869,31 +944,17 @@ def assistant(state: AgentState) -> Dict[str, Any]:
869
  print(f"Extracted tool: {tool_name}")
870
  print(f"Tool input: {tool_input}")
871
 
872
- # Create a tool call ID for the ToolMessage
873
  tool_call_id = f"call_{random.randint(1000000, 9999999)}"
874
 
875
- # Create state update with the assistant's response included
876
- state_update = {
877
- "messages": state["messages"] + [assistant_message], # Add full assistant response to history
878
- "current_tool": tool_name,
879
- "tool_call_id": tool_call_id
880
- }
881
-
882
- # Add action_input to state
883
- if isinstance(tool_input, dict):
884
- state_update["action_input"] = tool_input
885
-
886
- return state_update
887
-
888
- # No tool calls or end of chain indicated by "Final Answer"
889
- if "Final Answer:" in response.content:
890
- print("Final answer detected")
891
 
892
- return {
893
- "messages": state["messages"] + [assistant_message], # Add full assistant response to history
894
- "current_tool": None,
895
- "action_input": None
896
- }
897
 
898
  def extract_json_from_text(text: str) -> dict:
899
  """Extract JSON from text, handling markdown code blocks."""
@@ -1403,11 +1464,11 @@ def create_agent_graph() -> StateGraph:
1403
 
1404
  # Main agent class that integrates with your existing app.py
1405
  class TurboNerd:
1406
- def __init__(self, max_execution_time=60, apify_api_token=None):
1407
  self.graph = create_agent_graph()
1408
  self.tools = tools_config
1409
- self.max_execution_time = max_execution_time # Maximum execution time in seconds
1410
-
1411
  # Set Apify API token if provided
1412
  if apify_api_token:
1413
  os.environ["APIFY_API_TOKEN"] = apify_api_token
@@ -1419,16 +1480,16 @@ class TurboNerd:
1419
  initial_state = {
1420
  "messages": [HumanMessage(content=f"Question: {question}")],
1421
  "current_tool": None,
1422
- "action_input": None
 
1423
  }
1424
 
1425
- # Run the graph with timeout
1426
  print(f"Starting graph execution with question: {question}")
1427
- start_time = time.time()
1428
 
1429
  try:
1430
- # Set a reasonable recursion limit
1431
- result = self.graph.invoke(initial_state, {"recursion_limit": 100})
1432
 
1433
  # Print the final state for debugging
1434
  print(f"Final state keys: {result.keys()}")
@@ -1439,7 +1500,7 @@ class TurboNerd:
1439
  print("Final message: ", final_message)
1440
  # Extract just the final answer part
1441
  if "Final Answer:" in final_message:
1442
- final_answer = final_message.split("Final Answer:")[1].strip()
1443
  return final_answer
1444
 
1445
  return final_message
@@ -1451,18 +1512,8 @@ class TurboNerd:
1451
 
1452
  # Example usage:
1453
  if __name__ == "__main__":
1454
- agent = TurboNerd(max_execution_time=60)
1455
- response = agent("""Given this table defining * on the set S = {a, b, c, d, e}
1456
-
1457
- |*|a|b|c|d|e|
1458
- |---|---|---|---|---|---|
1459
- |a|a|b|c|b|d|
1460
- |b|b|c|a|e|c|
1461
- |c|c|a|b|b|a|
1462
- |d|b|e|b|e|d|
1463
- |e|d|b|a|d|c|
1464
-
1465
- provide the subset of S involved in any possible counter-examples that prove * is not commutative. Provide your answer as a comma separated list of the elements in the set in alphabetical order.""")
1466
  print("\nFinal Response:")
1467
  print(response)
1468
 
 
757
  }
758
  ```
759
  Observation: [another tool result will appear here]
760
+
761
  IMPORTANT: You MUST strictly follow the ReAct pattern (Reasoning, Action, Observation):
762
  1. First reason about the problem in the "Thought" section
763
  2. Then decide what action to take in the "Action" section (using the tools)
 
765
  4. Based on the observation, continue with another thought
766
  5. This cycle repeats until you have enough information to provide a final answer
767
 
768
+ NEVER fake or simulate tool output yourself. If you are unable to make progress in a certain way, try a different tool or a different approach.
769
 
770
  ... (this Thought/Action/Observation cycle can repeat as needed) ...
771
  Thought: I now know the final answer
772
+ Final Answer: YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
773
  Make sure to follow any formatting instructions given by the user.
774
  Now begin! Reminder to ALWAYS use the exact characters `Final Answer:` when you provide a definitive answer."""
775
 
 
837
  messages: Annotated[list[AnyMessage], add_messages]
838
  current_tool: Optional[str]
839
  action_input: Optional[ActionInput]
840
+ iteration_count: int # Added to track iterations
841
+ # tool_call_id: Optional[str] # Ensure this is present if used by your graph logic for tools
842
+
843
# Prune the accumulated ReAct transcript so repeated LLM calls stay within context limits.
def prune_messages_for_llm(
    full_history: List[AnyMessage],
    num_recent_to_keep: int = 6  # Keeps roughly 2-3 ReAct turns (Thought/Action, Observation)
) -> List[AnyMessage]:
    """
    Prune a 'core' message history (one without the initial SystemMessage) for the LLM call.

    Keeps the first HumanMessage (the original query) and the last `num_recent_to_keep`
    messages of the core history, with a condensation note injected between them.
    The caller (the `assistant` node) is responsible for prepending the SystemMessage.
    """
    if not full_history:  # full_history here is actually the core history
        return []

    # Locate the original user query: the first HumanMessage in the core history.
    first_human_message: Optional[HumanMessage] = next(
        (msg for msg in full_history if isinstance(msg, HumanMessage)),
        None,
    )

    # If the history is too short, or no initial human query was found, return it
    # unchanged. Threshold: first_human (1) + condensation note (1) + num_recent_to_keep.
    if first_human_message is None or len(full_history) < num_recent_to_keep + 2:
        return full_history

    condensation_note = AIMessage(
        content="[System note: To manage context length, earlier parts of the conversation have been omitted. The original query and the most recent interactions are preserved.]"
    )

    # Keep the recent tail, dropping the original query if it happens to fall inside it.
    # Object identity (`is`) is used so an equal-but-distinct message is still kept.
    recent_tail = [
        msg
        for msg in full_history[-num_recent_to_keep:]
        if msg is not first_human_message
    ]

    return [first_human_message, condensation_note] + recent_tail
888
 
889
def assistant(state: AgentState) -> Dict[str, Any]:
    """Assistant node that processes messages and decides on next action.

    Builds the LLM input as [SystemMessage] + (periodically pruned) core history,
    invokes the model with a stop sequence so it cannot fabricate an Observation,
    extracts a tool action (if any) from the response text, and returns a partial
    state update. The returned "messages" list contains only the new AIMessage;
    the graph's `add_messages` reducer appends it to the accumulated state.
    """
    print("Assistant Called...\n\n")

    full_current_history = state["messages"]
    iteration_count = state.get("iteration_count", 0) + 1  # increment for the current call
    print(f"Current Iteration: {iteration_count}")

    # Prepare messages for the LLM.
    system_msg = SystemMessage(content=SYSTEM_PROMPT)

    # Core history excludes any SystemMessages found in the accumulated history.
    # The pruning function operates on this core history only.
    core_history = [msg for msg in full_current_history if not isinstance(msg, SystemMessage)]

    llm_input_core_messages: List[AnyMessage]

    # Prune periodically: iterations 1-5 run with the full history; pruning kicks in
    # at iteration 6 and then every 5th iteration after that (11, 16, ...).
    if iteration_count > 5 and (iteration_count - 1) % 5 == 0:
        print(f"Pruning message history for LLM call at iteration {iteration_count}.")
        llm_input_core_messages = prune_messages_for_llm(core_history, num_recent_to_keep=6)
    else:
        llm_input_core_messages = core_history

    # Combine the system message with the (potentially pruned) core messages.
    messages_for_llm = [system_msg] + llm_input_core_messages

    # Get the response; stop at "Observation:" so the model never fakes tool output.
    response = chat_with_tools.invoke(messages_for_llm, stop=["Observation:"])
    print(f"Assistant response type: {type(response)}")
    content_preview = response.content[:300].replace('\n', ' ')
    print(f"Response content (first 300 chars): {content_preview}...")

    # Extract the action JSON from the response text.
    action_json = extract_json_from_text(response.content)
    print(f"Extracted action JSON: {action_json}")

    assistant_response_message = AIMessage(content=response.content)

    state_update: Dict[str, Any] = {
        "messages": [assistant_response_message],  # appended to history by add_messages
        "iteration_count": iteration_count
    }

    if action_json and "action" in action_json and "action_input" in action_json:
        tool_name = action_json["action"]
        tool_input = action_json["action_input"]
        print(f"Extracted tool: {tool_name}")
        print(f"Tool input: {tool_input}")

        # NOTE(review): the previous revision also generated an unused
        # `tool_call_id` here; it was dead code (its only consumer was
        # commented out) and has been removed. Reintroduce it alongside a
        # `tool_call_id` state field if the graph's tool node needs one.
        state_update["current_tool"] = tool_name
        state_update["action_input"] = tool_input
    else:
        print("No tool action found or 'Final Answer' detected in response.")
        state_update["current_tool"] = None
        state_update["action_input"] = None

    return state_update
 
 
 
 
958
 
959
  def extract_json_from_text(text: str) -> dict:
960
  """Extract JSON from text, handling markdown code blocks."""
 
1464
 
1465
  # Main agent class that integrates with your existing app.py
1466
  class TurboNerd:
1467
+ def __init__(self, max_iterations=25, apify_api_token=None):
1468
  self.graph = create_agent_graph()
1469
  self.tools = tools_config
1470
+ self.max_iterations = max_iterations # Maximum iterations for the graph
1471
+
1472
  # Set Apify API token if provided
1473
  if apify_api_token:
1474
  os.environ["APIFY_API_TOKEN"] = apify_api_token
 
1480
  initial_state = {
1481
  "messages": [HumanMessage(content=f"Question: {question}")],
1482
  "current_tool": None,
1483
+ "action_input": None,
1484
+ "iteration_count": 0 # Initialize iteration_count
1485
  }
1486
 
1487
+ # Run the graph
1488
  print(f"Starting graph execution with question: {question}")
 
1489
 
1490
  try:
1491
+ # Set a reasonable recursion limit based on max_iterations
1492
+ result = self.graph.invoke(initial_state, {"recursion_limit": self.max_iterations})
1493
 
1494
  # Print the final state for debugging
1495
  print(f"Final state keys: {result.keys()}")
 
1500
  print("Final message: ", final_message)
1501
  # Extract just the final answer part
1502
  if "Final Answer:" in final_message:
1503
+ final_answer = final_message.split("Final Answer:", 1)[1].strip()
1504
  return final_answer
1505
 
1506
  return final_message
 
1512
 
1513
  # Example usage:
1514
  if __name__ == "__main__":
1515
+ agent = TurboNerd(max_iterations=25)
1516
+ response = agent("""On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?""")
 
 
 
 
 
 
 
 
 
 
1517
  print("\nFinal Response:")
1518
  print(response)
1519