Update functions/agent_helper_functions.py
Browse files
functions/agent_helper_functions.py
CHANGED
|
@@ -1,124 +1,80 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
import os
|
| 4 |
import time
|
| 5 |
import json
|
| 6 |
import logging
|
| 7 |
-
from smolagents import
|
| 8 |
-
from configuration import CHECK_MODEL, TOKEN_LIMITER, STEP_WAIT
|
| 9 |
|
| 10 |
-
# Module-level logger
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
SUMMARIZER_MODEL =
|
| 15 |
-
"deepseek-ai/DeepSeek-V3",
|
| 16 |
-
provider="together",
|
| 17 |
-
api_key=os.getenv("TOGETHER_API_KEY"),
|
| 18 |
-
temperature=0,
|
| 19 |
-
max_tokens=8000
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
|
| 23 |
def check_reasoning(final_answer: str, agent_memory) -> bool:
|
| 24 |
-
"""Checks the reasoning and plot of the agent's final answer."""
|
| 25 |
prompt = (
|
| 26 |
f"Here is a user-given task and the agent steps: "
|
| 27 |
f"{agent_memory.get_succinct_steps()}. "
|
| 28 |
"Please check that the reasoning process and answer are correct. "
|
| 29 |
"First list reasons why yes/no, then write your final decision: "
|
| 30 |
-
"PASS in caps lock if
|
| 31 |
f"Final answer: {final_answer}"
|
| 32 |
)
|
| 33 |
-
|
| 34 |
-
messages = [
|
| 35 |
-
{
|
| 36 |
-
"role": "user",
|
| 37 |
-
"content": [
|
| 38 |
-
{
|
| 39 |
-
"type": "text",
|
| 40 |
-
"text": prompt,
|
| 41 |
-
}
|
| 42 |
-
],
|
| 43 |
-
}
|
| 44 |
-
]
|
| 45 |
-
|
| 46 |
feedback = CHECK_MODEL(messages).content
|
| 47 |
-
print("Feedback:
|
| 48 |
-
|
| 49 |
if "FAIL" in feedback:
|
| 50 |
raise Exception(feedback)
|
| 51 |
return True
|
| 52 |
|
| 53 |
-
|
| 54 |
-
def step_memory_cap(memory_step: ActionStep, agent: CodeAgent) -> None:
|
| 55 |
-
"""Removes old steps from agent memory to keep context length under control."""
|
| 56 |
-
task_step = agent.memory.steps[0]
|
| 57 |
-
planning_step = agent.memory.steps[1]
|
| 58 |
-
latest_step = agent.memory.steps[-1]
|
| 59 |
-
|
| 60 |
-
# Keep only the first two + latest
|
| 61 |
if len(agent.memory.steps) > 2:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
| 75 |
if summary:
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
latest_step.model_input_messages[0],
|
| 79 |
{
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
}]
|
| 85 |
}
|
| 86 |
]
|
| 87 |
-
agent.memory.steps = [
|
| 88 |
-
agent.memory.steps[0].model_input_messages =
|
| 89 |
-
logger.debug('Old messages summarized into new context.')
|
| 90 |
-
|
| 91 |
|
| 92 |
def summarize_old_messages(messages: list) -> str | None:
|
| 93 |
-
"""Summarizes old messages to keep context length under control using DeepSeek."""
|
| 94 |
if not messages:
|
| 95 |
return None
|
| 96 |
-
|
| 97 |
prompt = (
|
| 98 |
"Summarize the following interaction between an AI agent and a user "
|
| 99 |
-
"in plain text (not JSON): "
|
| 100 |
-
+ json.dumps(messages)
|
| 101 |
)
|
| 102 |
-
chat_input = [
|
| 103 |
-
{
|
| 104 |
-
"role": "user",
|
| 105 |
-
"content": [
|
| 106 |
-
{"type": "text", "text": prompt}
|
| 107 |
-
]
|
| 108 |
-
}
|
| 109 |
-
]
|
| 110 |
-
|
| 111 |
try:
|
| 112 |
-
|
| 113 |
-
return
|
| 114 |
except Exception as e:
|
| 115 |
logger.error("Error during summarization: %s", e)
|
| 116 |
return None
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
"""Waits to prevent hitting API rate limits."""
|
| 121 |
-
logger.info('Waiting %d seconds (step %d)', STEP_WAIT,
|
| 122 |
-
memory_step.step_number)
|
| 123 |
time.sleep(STEP_WAIT)
|
| 124 |
-
return True
|
|
|
|
| 1 |
+
# functions/agent_helper_functions.py
|
| 2 |
+
|
| 3 |
+
"""Helper functions for the GAIA agent."""
|
| 4 |
|
|
|
|
| 5 |
import time
|
| 6 |
import json
|
| 7 |
import logging
|
| 8 |
+
from smolagents import ActionStep, MessageRole
|
| 9 |
+
from configuration import MODEL, CHECK_MODEL, TOKEN_LIMITER, STEP_WAIT
|
| 10 |
|
|
|
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
+
# Reuse the main MODEL for summaries as well
|
| 14 |
+
SUMMARIZER_MODEL = MODEL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def check_reasoning(final_answer: str, agent_memory) -> bool:
    """Ask CHECK_MODEL to verify the agent's reasoning and final answer.

    Args:
        final_answer: The agent's proposed final answer.
        agent_memory: Agent memory object exposing ``get_succinct_steps()``.

    Returns:
        True when the checker's feedback does not contain "FAIL".

    Raises:
        Exception: With the checker's feedback text when it contains "FAIL".
    """
    prompt = (
        f"Here is a user-given task and the agent steps: "
        f"{agent_memory.get_succinct_steps()}. "
        "Please check that the reasoning process and answer are correct. "
        "First list reasons why yes/no, then write your final decision: "
        "PASS in caps lock if satisfactory, FAIL if not. "
        f"Final answer: {final_answer}"
    )
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    feedback = CHECK_MODEL(messages).content
    # Use the module logger instead of print, for consistency with the
    # rest of this module's logging style.
    logger.info("Feedback: %s", feedback)
    if "FAIL" in feedback:
        raise Exception(feedback)
    return True
|
| 31 |
|
| 32 |
+
def step_memory_cap(memory_step: ActionStep, agent) -> None:
    """Trim agent memory to keep the model context under control.

    Keeps only the first two steps (task + planning) and the most recent
    step; when the latest step's token usage exceeds TOKEN_LIMITER, the
    old conversation is replaced by a plain-text summary.

    Intended as a smolagents step callback; ``memory_step`` is unused here
    but required by the callback signature.
    """
    if len(agent.memory.steps) > 2:
        # Keep only the task step, the planning step and the latest step.
        first, second, last = (
            agent.memory.steps[0],
            agent.memory.steps[1],
            agent.memory.steps[-1]
        )
        agent.memory.steps = [first, second, last]

    latest = agent.memory.steps[-1]
    # Guard: token_usage can be absent on steps that made no model call,
    # which would otherwise raise AttributeError on .total_tokens.
    if latest.token_usage and latest.token_usage.total_tokens > TOKEN_LIMITER:
        logger.info(
            "Token usage %d > %d, summarizing",
            latest.token_usage.total_tokens, TOKEN_LIMITER
        )
        # Summarize everything after the first (system/task) message.
        summary = summarize_old_messages(latest.model_input_messages[1:])
        if summary:
            new_msgs = [
                latest.model_input_messages[0],
                {
                    "role": MessageRole.USER,
                    "content": [{
                        "type": "text",
                        "text": f"Here is a summary of your investigation so far: {summary}"
                    }]
                }
            ]
            # Collapse memory to the task step, carrying the summary as
            # its new input context.
            agent.memory.steps = [agent.memory.steps[0]]
            agent.memory.steps[0].model_input_messages = new_msgs
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def summarize_old_messages(messages: list) -> str | None:
|
|
|
|
| 63 |
if not messages:
|
| 64 |
return None
|
|
|
|
| 65 |
prompt = (
|
| 66 |
"Summarize the following interaction between an AI agent and a user "
|
| 67 |
+
"in plain text (not JSON): " + json.dumps(messages)
|
|
|
|
| 68 |
)
|
| 69 |
+
chat_input = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
try:
|
| 71 |
+
resp = SUMMARIZER_MODEL(chat_input)
|
| 72 |
+
return resp.content
|
| 73 |
except Exception as e:
|
| 74 |
logger.error("Error during summarization: %s", e)
|
| 75 |
return None
|
| 76 |
|
| 77 |
+
def step_wait(memory_step: ActionStep, agent) -> bool:
    """Sleep STEP_WAIT seconds after a step to respect API rate limits."""
    step_no = memory_step.step_number
    logger.info("Waiting %d seconds (step %d)", STEP_WAIT, step_no)
    time.sleep(STEP_WAIT)
    return True
|