AlessandroMasala commited on
Commit
6fc7a1b
·
verified ·
1 Parent(s): 4b20138

Update functions/agent_helper_functions.py

Browse files
Files changed (1) hide show
  1. functions/agent_helper_functions.py +40 -84
functions/agent_helper_functions.py CHANGED
@@ -1,124 +1,80 @@
1
- """Helper functions for the agent(s) in the GAIA question answering system."""
 
 
2
 
3
- import os
4
  import time
5
  import json
6
  import logging
7
- from smolagents import CodeAgent, ActionStep, MessageRole, InferenceClientModel
8
- from configuration import CHECK_MODEL, TOKEN_LIMITER, STEP_WAIT
9
 
10
- # Logger per questo modulo
11
  logger = logging.getLogger(__name__)
12
 
13
- # Model dedicato alla summarizzazione dei messaggi
14
- SUMMARIZER_MODEL = InferenceClientModel(
15
- "deepseek-ai/DeepSeek-V3",
16
- provider="together",
17
- api_key=os.getenv("TOGETHER_API_KEY"),
18
- temperature=0,
19
- max_tokens=8000
20
- )
21
-
22
 
23
  def check_reasoning(final_answer: str, agent_memory) -> bool:
24
- """Checks the reasoning and plot of the agent's final answer."""
25
  prompt = (
26
  f"Here is a user-given task and the agent steps: "
27
  f"{agent_memory.get_succinct_steps()}. "
28
  "Please check that the reasoning process and answer are correct. "
29
  "First list reasons why yes/no, then write your final decision: "
30
- "PASS in caps lock if it is satisfactory, FAIL if it is not. "
31
  f"Final answer: {final_answer}"
32
  )
33
-
34
- messages = [
35
- {
36
- "role": "user",
37
- "content": [
38
- {
39
- "type": "text",
40
- "text": prompt,
41
- }
42
- ],
43
- }
44
- ]
45
-
46
  feedback = CHECK_MODEL(messages).content
47
- print("Feedback: ", feedback)
48
-
49
  if "FAIL" in feedback:
50
  raise Exception(feedback)
51
  return True
52
 
53
-
54
- def step_memory_cap(memory_step: ActionStep, agent: CodeAgent) -> None:
55
- """Removes old steps from agent memory to keep context length under control."""
56
- task_step = agent.memory.steps[0]
57
- planning_step = agent.memory.steps[1]
58
- latest_step = agent.memory.steps[-1]
59
-
60
- # Keep only the first two + latest
61
  if len(agent.memory.steps) > 2:
62
- agent.memory.steps = [task_step, planning_step, latest_step]
63
-
64
- logger.info('Agent memory has %d steps', len(agent.memory.steps))
65
- logger.info('Latest step is %d', memory_step.step_number)
66
- logger.info('Messages in latest step: %d',
67
- len(latest_step.model_input_messages))
68
- logger.info('Token usage: %d', latest_step.token_usage.total_tokens)
69
-
70
- # If troppo token usage, faccio il summary
71
- if latest_step.token_usage.total_tokens > TOKEN_LIMITER:
72
- logger.info('Token usage %d > %d, summarizing old messages',
73
- latest_step.token_usage.total_tokens, TOKEN_LIMITER)
74
- summary = summarize_old_messages(latest_step.model_input_messages[1:])
 
75
  if summary:
76
- # Ricostruisco il solo step con il summary
77
- new_messages = [
78
- latest_step.model_input_messages[0],
79
  {
80
- 'role': MessageRole.USER,
81
- 'content': [{
82
- 'type': 'text',
83
- 'text': f'Here is a summary of your investigation so far: {summary}'
84
  }]
85
  }
86
  ]
87
- agent.memory.steps = [task_step]
88
- agent.memory.steps[0].model_input_messages = new_messages
89
- logger.debug('Old messages summarized into new context.')
90
-
91
 
92
  def summarize_old_messages(messages: list) -> str | None:
93
- """Summarizes old messages to keep context length under control using DeepSeek."""
94
  if not messages:
95
  return None
96
-
97
  prompt = (
98
  "Summarize the following interaction between an AI agent and a user "
99
- "in plain text (not JSON): "
100
- + json.dumps(messages)
101
  )
102
- chat_input = [
103
- {
104
- "role": "user",
105
- "content": [
106
- {"type": "text", "text": prompt}
107
- ]
108
- }
109
- ]
110
-
111
  try:
112
- response = SUMMARIZER_MODEL(chat_input)
113
- return response.content
114
  except Exception as e:
115
  logger.error("Error during summarization: %s", e)
116
  return None
117
 
118
-
119
- def step_wait(memory_step: ActionStep, agent: CodeAgent) -> bool:
120
- """Waits to prevent hitting API rate limits."""
121
- logger.info('Waiting %d seconds (step %d)', STEP_WAIT,
122
- memory_step.step_number)
123
  time.sleep(STEP_WAIT)
124
- return True
 
1
+ # functions/agent_helper_functions.py
2
+
3
+ """Helper functions for the GAIA agent."""
4
 
 
5
  import time
6
  import json
7
  import logging
8
+ from smolagents import ActionStep, MessageRole
9
+ from configuration import MODEL, CHECK_MODEL, TOKEN_LIMITER, STEP_WAIT
10
 
 
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Use the same MODEL for summaries as well
14
+ SUMMARIZER_MODEL = MODEL
 
 
 
 
 
 
 
15
 
16
def check_reasoning(final_answer: str, agent_memory) -> bool:
    """Ask CHECK_MODEL to verify the agent's reasoning and final answer.

    Args:
        final_answer: The answer produced by the agent.
        agent_memory: Agent memory object exposing ``get_succinct_steps()``.

    Returns:
        True when the checker's feedback does not contain "FAIL".

    Raises:
        Exception: Carrying the checker's feedback when it replies FAIL,
            so the caller can retry or surface the critique.
    """
    prompt = (
        f"Here is a user-given task and the agent steps: "
        f"{agent_memory.get_succinct_steps()}. "
        "Please check that the reasoning process and answer are correct. "
        "First list reasons why yes/no, then write your final decision: "
        "PASS in caps lock if satisfactory, FAIL if not. "
        f"Final answer: {final_answer}"
    )
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    feedback = CHECK_MODEL(messages).content
    # Route feedback through the module logger instead of print() so it
    # honors the application's logging configuration, consistent with the
    # rest of this module.
    logger.info("Feedback: %s", feedback)
    if "FAIL" in feedback:
        raise Exception(feedback)
    return True
31
 
32
def step_memory_cap(memory_step: ActionStep, agent) -> None:
    """Step callback that bounds the agent's memory growth.

    Trims memory down to the first two steps plus the newest one, then,
    if the newest step's token usage exceeds TOKEN_LIMITER, collapses the
    history into a single step carrying a text summary of past messages.

    Args:
        memory_step: The step that just completed (callback signature).
        agent: The running agent whose ``memory.steps`` is mutated in place.
    """
    steps = agent.memory.steps
    if len(steps) > 2:
        # Keep only: initial (task) step, second (planning) step, newest step.
        agent.memory.steps = [steps[0], steps[1], steps[-1]]

    newest = agent.memory.steps[-1]
    tokens_used = newest.token_usage.total_tokens
    if tokens_used > TOKEN_LIMITER:
        logger.info(
            "Token usage %d > %d, summarizing",
            tokens_used, TOKEN_LIMITER
        )
        # Summarize everything after the first message of the newest step.
        summary = summarize_old_messages(newest.model_input_messages[1:])
        if summary:
            summary_message = {
                "role": MessageRole.USER,
                "content": [{
                    "type": "text",
                    "text": f"Here is a summary of your investigation so far: {summary}"
                }]
            }
            replacement = [newest.model_input_messages[0], summary_message]
            # Collapse memory to the first step, whose input messages now
            # consist of the original first message plus the summary.
            agent.memory.steps = [agent.memory.steps[0]]
            agent.memory.steps[0].model_input_messages = replacement
 
 
61
 
62
  def summarize_old_messages(messages: list) -> str | None:
 
63
  if not messages:
64
  return None
 
65
  prompt = (
66
  "Summarize the following interaction between an AI agent and a user "
67
+ "in plain text (not JSON): " + json.dumps(messages)
 
68
  )
69
+ chat_input = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
 
 
 
 
 
 
 
 
70
  try:
71
+ resp = SUMMARIZER_MODEL(chat_input)
72
+ return resp.content
73
  except Exception as e:
74
  logger.error("Error during summarization: %s", e)
75
  return None
76
 
77
def step_wait(memory_step: ActionStep, agent) -> bool:
    """Step callback that pauses between steps to avoid API rate limits.

    Args:
        memory_step: The step that just completed (used for logging only).
        agent: The running agent (unused; required by callback signature).

    Returns:
        True, so the callback chain continues.
    """
    pause = STEP_WAIT
    logger.info("Waiting %d seconds (step %d)", pause, memory_step.step_number)
    time.sleep(pause)
    return True