Spaces:
Sleeping
Sleeping
Implemented prompt pruning
Browse files
agent.py
CHANGED
|
@@ -757,6 +757,7 @@ Action:
|
|
| 757 |
}
|
| 758 |
```
|
| 759 |
Observation: [another tool result will appear here]
|
|
|
|
| 760 |
IMPORTANT: You MUST strictly follow the ReAct pattern (Reasoning, Action, Observation):
|
| 761 |
1. First reason about the problem in the "Thought" section
|
| 762 |
2. Then decide what action to take in the "Action" section (using the tools)
|
|
@@ -764,11 +765,11 @@ IMPORTANT: You MUST strictly follow the ReAct pattern (Reasoning, Action, Observ
|
|
| 764 |
4. Based on the observation, continue with another thought
|
| 765 |
5. This cycle repeats until you have enough information to provide a final answer
|
| 766 |
|
| 767 |
-
NEVER fake or simulate tool output yourself.
|
| 768 |
|
| 769 |
... (this Thought/Action/Observation cycle can repeat as needed) ...
|
| 770 |
Thought: I now know the final answer
|
| 771 |
-
Final Answer:
|
| 772 |
Make sure to follow any formatting instructions given by the user.
|
| 773 |
Now begin! Reminder to ALWAYS use the exact characters `Final Answer:` when you provide a definitive answer."""
|
| 774 |
|
|
@@ -836,32 +837,106 @@ class AgentState(TypedDict, total=False):
|
|
| 836 |
messages: Annotated[list[AnyMessage], add_messages]
|
| 837 |
current_tool: Optional[str]
|
| 838 |
action_input: Optional[ActionInput]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 839 |
|
| 840 |
def assistant(state: AgentState) -> Dict[str, Any]:
|
| 841 |
"""Assistant node that processes messages and decides on next action."""
|
| 842 |
print("Assistant Called...\n\n")
|
| 843 |
|
| 844 |
-
|
| 845 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 846 |
system_msg = SystemMessage(content=SYSTEM_PROMPT)
|
| 847 |
|
| 848 |
-
#
|
| 849 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 850 |
|
| 851 |
-
# Combine system message with
|
| 852 |
-
|
| 853 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 854 |
# Get response from the assistant
|
| 855 |
-
response = chat_with_tools.invoke(
|
| 856 |
print(f"Assistant response type: {type(response)}")
|
| 857 |
-
print(f"Response content: {response.content}...")
|
|
|
|
|
|
|
| 858 |
|
| 859 |
# Extract the action JSON from the response text
|
| 860 |
action_json = extract_json_from_text(response.content)
|
| 861 |
print(f"Extracted action JSON: {action_json}")
|
| 862 |
|
| 863 |
-
|
| 864 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 865 |
|
| 866 |
if action_json and "action" in action_json and "action_input" in action_json:
|
| 867 |
tool_name = action_json["action"]
|
|
@@ -869,31 +944,17 @@ def assistant(state: AgentState) -> Dict[str, Any]:
|
|
| 869 |
print(f"Extracted tool: {tool_name}")
|
| 870 |
print(f"Tool input: {tool_input}")
|
| 871 |
|
| 872 |
-
# Create a tool call ID for the ToolMessage
|
| 873 |
tool_call_id = f"call_{random.randint(1000000, 9999999)}"
|
| 874 |
|
| 875 |
-
|
| 876 |
-
state_update =
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
# Add action_input to state
|
| 883 |
-
if isinstance(tool_input, dict):
|
| 884 |
-
state_update["action_input"] = tool_input
|
| 885 |
-
|
| 886 |
-
return state_update
|
| 887 |
-
|
| 888 |
-
# No tool calls or end of chain indicated by "Final Answer"
|
| 889 |
-
if "Final Answer:" in response.content:
|
| 890 |
-
print("Final answer detected")
|
| 891 |
|
| 892 |
-
return
|
| 893 |
-
"messages": state["messages"] + [assistant_message], # Add full assistant response to history
|
| 894 |
-
"current_tool": None,
|
| 895 |
-
"action_input": None
|
| 896 |
-
}
|
| 897 |
|
| 898 |
def extract_json_from_text(text: str) -> dict:
|
| 899 |
"""Extract JSON from text, handling markdown code blocks."""
|
|
@@ -1403,11 +1464,11 @@ def create_agent_graph() -> StateGraph:
|
|
| 1403 |
|
| 1404 |
# Main agent class that integrates with your existing app.py
|
| 1405 |
class TurboNerd:
|
| 1406 |
-
def __init__(self,
|
| 1407 |
self.graph = create_agent_graph()
|
| 1408 |
self.tools = tools_config
|
| 1409 |
-
self.
|
| 1410 |
-
|
| 1411 |
# Set Apify API token if provided
|
| 1412 |
if apify_api_token:
|
| 1413 |
os.environ["APIFY_API_TOKEN"] = apify_api_token
|
|
@@ -1419,16 +1480,16 @@ class TurboNerd:
|
|
| 1419 |
initial_state = {
|
| 1420 |
"messages": [HumanMessage(content=f"Question: {question}")],
|
| 1421 |
"current_tool": None,
|
| 1422 |
-
"action_input": None
|
|
|
|
| 1423 |
}
|
| 1424 |
|
| 1425 |
-
# Run the graph
|
| 1426 |
print(f"Starting graph execution with question: {question}")
|
| 1427 |
-
start_time = time.time()
|
| 1428 |
|
| 1429 |
try:
|
| 1430 |
-
# Set a reasonable recursion limit
|
| 1431 |
-
result = self.graph.invoke(initial_state, {"recursion_limit":
|
| 1432 |
|
| 1433 |
# Print the final state for debugging
|
| 1434 |
print(f"Final state keys: {result.keys()}")
|
|
@@ -1439,7 +1500,7 @@ class TurboNerd:
|
|
| 1439 |
print("Final message: ", final_message)
|
| 1440 |
# Extract just the final answer part
|
| 1441 |
if "Final Answer:" in final_message:
|
| 1442 |
-
final_answer = final_message.split("Final Answer:")[1].strip()
|
| 1443 |
return final_answer
|
| 1444 |
|
| 1445 |
return final_message
|
|
@@ -1451,18 +1512,8 @@ class TurboNerd:
|
|
| 1451 |
|
| 1452 |
# Example usage:
|
| 1453 |
if __name__ == "__main__":
|
| 1454 |
-
agent = TurboNerd(
|
| 1455 |
-
response = agent("""
|
| 1456 |
-
|
| 1457 |
-
|*|a|b|c|d|e|
|
| 1458 |
-
|---|---|---|---|---|---|
|
| 1459 |
-
|a|a|b|c|b|d|
|
| 1460 |
-
|b|b|c|a|e|c|
|
| 1461 |
-
|c|c|a|b|b|a|
|
| 1462 |
-
|d|b|e|b|e|d|
|
| 1463 |
-
|e|d|b|a|d|c|
|
| 1464 |
-
|
| 1465 |
-
provide the subset of S involved in any possible counter-examples that prove * is not commutative. Provide your answer as a comma separated list of the elements in the set in alphabetical order.""")
|
| 1466 |
print("\nFinal Response:")
|
| 1467 |
print(response)
|
| 1468 |
|
|
|
|
| 757 |
}
|
| 758 |
```
|
| 759 |
Observation: [another tool result will appear here]
|
| 760 |
+
|
| 761 |
IMPORTANT: You MUST strictly follow the ReAct pattern (Reasoning, Action, Observation):
|
| 762 |
1. First reason about the problem in the "Thought" section
|
| 763 |
2. Then decide what action to take in the "Action" section (using the tools)
|
|
|
|
| 765 |
4. Based on the observation, continue with another thought
|
| 766 |
5. This cycle repeats until you have enough information to provide a final answer
|
| 767 |
|
| 768 |
+
NEVER fake or simulate tool output yourself. If you are unable to make progreess in a certain way, try a different tool or a different approach.
|
| 769 |
|
| 770 |
... (this Thought/Action/Observation cycle can repeat as needed) ...
|
| 771 |
Thought: I now know the final answer
|
| 772 |
+
Final Answer: YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
| 773 |
Make sure to follow any formatting instructions given by the user.
|
| 774 |
Now begin! Reminder to ALWAYS use the exact characters `Final Answer:` when you provide a definitive answer."""
|
| 775 |
|
|
|
|
| 837 |
messages: Annotated[list[AnyMessage], add_messages]
|
| 838 |
current_tool: Optional[str]
|
| 839 |
action_input: Optional[ActionInput]
|
| 840 |
+
iteration_count: int # Added to track iterations
|
| 841 |
+
# tool_call_id: Optional[str] # Ensure this is present if used by your graph logic for tools
|
| 842 |
+
|
| 843 |
+
# Add prune_messages_for_llm function
|
| 844 |
+
def prune_messages_for_llm(
|
| 845 |
+
full_history: List[AnyMessage],
|
| 846 |
+
num_recent_to_keep: int = 6 # Keeps roughly 2-3 ReAct turns (Thought/Action, Observation)
|
| 847 |
+
) -> List[AnyMessage]:
|
| 848 |
+
"""
|
| 849 |
+
Prunes the message history for the LLM call.
|
| 850 |
+
This function expects a 'core' history (messages without the initial SystemMessage).
|
| 851 |
+
It keeps the first HumanMessage (original query) and the last `num_recent_to_keep` messages
|
| 852 |
+
from this core history, injecting a condensation note.
|
| 853 |
+
"""
|
| 854 |
+
if not full_history: # full_history here is actually core_history
|
| 855 |
+
return []
|
| 856 |
+
|
| 857 |
+
first_human_message: Optional[HumanMessage] = None
|
| 858 |
+
for msg in full_history: # Iterate over the provided core_history
|
| 859 |
+
if isinstance(msg, HumanMessage):
|
| 860 |
+
first_human_message = msg
|
| 861 |
+
break
|
| 862 |
+
|
| 863 |
+
# If history is too short or no initial human query found in core_history,
|
| 864 |
+
# return core_history as is. The calling function (assistant) will prepend SystemMessage.
|
| 865 |
+
# Threshold considers: first_human (1) + condensation_note (1) + num_recent_to_keep
|
| 866 |
+
if first_human_message is None or len(full_history) < (1 + 1 + num_recent_to_keep):
|
| 867 |
+
return full_history
|
| 868 |
+
|
| 869 |
+
# Pruning is needed for the core_history
|
| 870 |
+
recent_messages_from_core = full_history[-num_recent_to_keep:]
|
| 871 |
+
|
| 872 |
+
pruned_core_list: List[AnyMessage] = []
|
| 873 |
+
|
| 874 |
+
# Add the first human message
|
| 875 |
+
pruned_core_list.append(first_human_message)
|
| 876 |
+
|
| 877 |
+
# Add condensation note
|
| 878 |
+
pruned_core_list.append(
|
| 879 |
+
AIMessage(content="[System note: To manage context length, earlier parts of the conversation have been omitted. The original query and the most recent interactions are preserved.]")
|
| 880 |
+
)
|
| 881 |
+
|
| 882 |
+
# Add recent messages, ensuring not to duplicate the first_human_message if it's in the recent slice
|
| 883 |
+
for msg in recent_messages_from_core:
|
| 884 |
+
if msg is not first_human_message: # Check object identity
|
| 885 |
+
pruned_core_list.append(msg)
|
| 886 |
+
|
| 887 |
+
return pruned_core_list
|
| 888 |
|
| 889 |
def assistant(state: AgentState) -> Dict[str, Any]:
|
| 890 |
"""Assistant node that processes messages and decides on next action."""
|
| 891 |
print("Assistant Called...\n\n")
|
| 892 |
|
| 893 |
+
full_current_history = state["messages"]
|
| 894 |
+
iteration_count = state.get("iteration_count", 0)
|
| 895 |
+
iteration_count += 1 # Increment for the current call
|
| 896 |
+
print(f"Current Iteration: {iteration_count}")
|
| 897 |
+
|
| 898 |
+
# Prepare messages for the LLM
|
| 899 |
system_msg = SystemMessage(content=SYSTEM_PROMPT)
|
| 900 |
|
| 901 |
+
# Core history excludes any SystemMessages found in the accumulated history.
|
| 902 |
+
# The pruning function operates on this core history.
|
| 903 |
+
core_history = [msg for msg in full_current_history if not isinstance(msg, SystemMessage)]
|
| 904 |
+
|
| 905 |
+
llm_input_core_messages: List[AnyMessage]
|
| 906 |
+
|
| 907 |
+
# Prune if it's time (e.g., after every 5th completed iteration, so check for current iteration 6, 11, etc.)
|
| 908 |
+
# Iteration 1-5: no pruning. Iteration 6: prune.
|
| 909 |
+
if iteration_count > 5 and (iteration_count - 1) % 5 == 0:
|
| 910 |
+
print(f"Pruning message history for LLM call at iteration {iteration_count}.")
|
| 911 |
+
llm_input_core_messages = prune_messages_for_llm(core_history, num_recent_to_keep=6)
|
| 912 |
+
else:
|
| 913 |
+
llm_input_core_messages = core_history
|
| 914 |
|
| 915 |
+
# Combine system message with the (potentially pruned) core messages
|
| 916 |
+
messages_for_llm = [system_msg] + llm_input_core_messages
|
| 917 |
|
| 918 |
+
# Log the messages being sent to LLM for debugging
|
| 919 |
+
# print(f"Messages for LLM (count: {len(messages_for_llm)}):")
|
| 920 |
+
# for i, msg in enumerate(messages_for_llm):
|
| 921 |
+
# print(f" {i}: Type={type(msg).__name__}, Content='{str(msg.content)[:100].replace('\\n', ' ')}...'")
|
| 922 |
+
|
| 923 |
# Get response from the assistant
|
| 924 |
+
response = chat_with_tools.invoke(messages_for_llm, stop=["Observation:"])
|
| 925 |
print(f"Assistant response type: {type(response)}")
|
| 926 |
+
# print(f"Response content (first 300 chars): {response.content[:300].replace('\n', ' ')}...")
|
| 927 |
+
content_preview = response.content[:300].replace('\n', ' ')
|
| 928 |
+
print(f"Response content (first 300 chars): {content_preview}...")
|
| 929 |
|
| 930 |
# Extract the action JSON from the response text
|
| 931 |
action_json = extract_json_from_text(response.content)
|
| 932 |
print(f"Extracted action JSON: {action_json}")
|
| 933 |
|
| 934 |
+
assistant_response_message = AIMessage(content=response.content)
|
| 935 |
+
|
| 936 |
+
state_update: Dict[str, Any] = {
|
| 937 |
+
"messages": [assistant_response_message],
|
| 938 |
+
"iteration_count": iteration_count
|
| 939 |
+
}
|
| 940 |
|
| 941 |
if action_json and "action" in action_json and "action_input" in action_json:
|
| 942 |
tool_name = action_json["action"]
|
|
|
|
| 944 |
print(f"Extracted tool: {tool_name}")
|
| 945 |
print(f"Tool input: {tool_input}")
|
| 946 |
|
|
|
|
| 947 |
tool_call_id = f"call_{random.randint(1000000, 9999999)}"
|
| 948 |
|
| 949 |
+
state_update["current_tool"] = tool_name
|
| 950 |
+
state_update["action_input"] = tool_input
|
| 951 |
+
# state_update["tool_call_id"] = tool_call_id # If needed by your graph
|
| 952 |
+
else:
|
| 953 |
+
print("No tool action found or 'Final Answer' detected in response.")
|
| 954 |
+
state_update["current_tool"] = None
|
| 955 |
+
state_update["action_input"] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 956 |
|
| 957 |
+
return state_update
|
|
|
|
|
|
|
|
|
|
|
|
|
| 958 |
|
| 959 |
def extract_json_from_text(text: str) -> dict:
|
| 960 |
"""Extract JSON from text, handling markdown code blocks."""
|
|
|
|
| 1464 |
|
| 1465 |
# Main agent class that integrates with your existing app.py
|
| 1466 |
class TurboNerd:
|
| 1467 |
+
def __init__(self, max_iterations=25, apify_api_token=None):
|
| 1468 |
self.graph = create_agent_graph()
|
| 1469 |
self.tools = tools_config
|
| 1470 |
+
self.max_iterations = max_iterations # Maximum iterations for the graph
|
| 1471 |
+
|
| 1472 |
# Set Apify API token if provided
|
| 1473 |
if apify_api_token:
|
| 1474 |
os.environ["APIFY_API_TOKEN"] = apify_api_token
|
|
|
|
| 1480 |
initial_state = {
|
| 1481 |
"messages": [HumanMessage(content=f"Question: {question}")],
|
| 1482 |
"current_tool": None,
|
| 1483 |
+
"action_input": None,
|
| 1484 |
+
"iteration_count": 0 # Initialize iteration_count
|
| 1485 |
}
|
| 1486 |
|
| 1487 |
+
# Run the graph
|
| 1488 |
print(f"Starting graph execution with question: {question}")
|
|
|
|
| 1489 |
|
| 1490 |
try:
|
| 1491 |
+
# Set a reasonable recursion limit based on max_iterations
|
| 1492 |
+
result = self.graph.invoke(initial_state, {"recursion_limit": self.max_iterations})
|
| 1493 |
|
| 1494 |
# Print the final state for debugging
|
| 1495 |
print(f"Final state keys: {result.keys()}")
|
|
|
|
| 1500 |
print("Final message: ", final_message)
|
| 1501 |
# Extract just the final answer part
|
| 1502 |
if "Final Answer:" in final_message:
|
| 1503 |
+
final_answer = final_message.split("Final Answer:", 1)[1].strip()
|
| 1504 |
return final_answer
|
| 1505 |
|
| 1506 |
return final_message
|
|
|
|
| 1512 |
|
| 1513 |
# Example usage:
|
| 1514 |
if __name__ == "__main__":
|
| 1515 |
+
agent = TurboNerd(max_iterations=25)
|
| 1516 |
+
response = agent("""On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1517 |
print("\nFinal Response:")
|
| 1518 |
print(response)
|
| 1519 |
|