Spaces:
Sleeping
Sleeping
Switched to a single agent powered by GPT-4.1, and added a step-wait function to avoid hitting the OpenAI API rate limit.
b4e2809 verified | '''Helper functions for the agent(s) in the GAIA question answering system.''' | |
| import os | |
| import time | |
| import json | |
| import logging | |
| from openai import OpenAI | |
| from smolagents import CodeAgent, ActionStep, MessageRole | |
| from configuration import CHECK_MODEL, TOKEN_LIMITER, STEP_WAIT | |
| # Get logger for this module | |
| logger = logging.getLogger(__name__) | |
def check_reasoning(final_answer: str, agent_memory) -> bool:
    """Checks the reasoning and final answer of an agent run with the check model.

    Sends the agent's succinct step history plus the final answer to
    CHECK_MODEL and asks it to grade the run PASS/FAIL.

    Args:
        final_answer: The answer produced by the agent.
        agent_memory: Agent memory object exposing get_succinct_steps().

    Returns:
        True if the check model judged the reasoning and answer satisfactory.

    Raises:
        Exception: If the check model's feedback contains 'FAIL'; the
            feedback text is used as the exception message.
    """

    prompt = (
        f"Here is a user-given task and the agent steps: {agent_memory.get_succinct_steps()}. " +
        "Please check that the reasoning process and answer are correct. " +
        "Do they correctly answer the given task? " +
        "First list reasons why yes/no, then write your final decision: " +
        "PASS in caps lock if it is satisfactory, FAIL if it is not. " +
        f"Final answer: {str(final_answer)}"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt,
                }
            ],
        }
    ]

    output = CHECK_MODEL(messages).content

    # Route feedback through the module logger instead of print() so it ends
    # up in the same sink as the rest of this module's diagnostics.
    logger.info("Feedback: %s", output)

    if "FAIL" in output:
        raise Exception(output)  # pylint:disable=broad-exception-raised

    return True
def step_memory_cap(memory_step: ActionStep, agent: CodeAgent) -> None:
    '''Removes old steps from agent memory to keep context length under control.

    Trims the agent's memory to [task step, planning step, latest step].
    If the latest step's token usage exceeds TOKEN_LIMITER, the older
    messages are summarized and the memory is collapsed to a single step
    whose input messages are the original first message plus the summary.

    Args:
        memory_step: The step that just completed (used for logging only).
        agent: The running CodeAgent whose memory is mutated in place.
    '''

    # Only index steps[1] inside the length guard: the original unconditional
    # task_step/planning_step lookups raised IndexError on runs with fewer
    # than two recorded steps.
    if len(agent.memory.steps) > 2:
        agent.memory.steps = [
            agent.memory.steps[0],   # task step
            agent.memory.steps[1],   # planning step
            agent.memory.steps[-1],  # latest step
        ]

    logger.info('Agent memory has %d steps', len(agent.memory.steps))
    logger.info('Latest step is step %d', memory_step.step_number)
    logger.info('Contains: %s messages', len(agent.memory.steps[-1].model_input_messages))
    logger.info('Token usage: %s', agent.memory.steps[-1].token_usage.total_tokens)

    for message in agent.memory.steps[-1].model_input_messages:
        logger.debug('  Role: %s: %s', message['role'], message['content'][:100])

    token_usage = agent.memory.steps[-1].token_usage.total_tokens

    if token_usage > TOKEN_LIMITER:
        logger.info('Token usage is %d, summarizing old messages', token_usage)

        # Summarize everything after the first (system) message.
        summary = summarize_old_messages(
            agent.memory.steps[-1].model_input_messages[1:]
        )

        if summary is not None:
            # Rebuild the conversation as: original system message + summary.
            new_messages = [agent.memory.steps[-1].model_input_messages[0]]
            new_messages.append({
                'role': MessageRole.USER,
                'content': [{
                    'type': 'text',
                    'text': f'Here is a summary of your investigation so far: {summary}'
                }]
            })

            # Collapse memory to the task step, carrying the new messages.
            agent.memory.steps = [agent.memory.steps[0]]
            agent.memory.steps[0].model_input_messages = new_messages

            for message in agent.memory.steps[0].model_input_messages:
                logger.debug('  Role: %s: %s', message['role'], message['content'][:100])
def summarize_old_messages(messages: list) -> 'str | None':
    '''Summarizes old messages to keep context length under control.

    Sends the message history to a Modal-hosted, OpenAI-compatible vLLM
    summarization endpoint and returns the model's text summary.

    Args:
        messages: List of chat message dicts to summarize (the original
            annotation said dict, but callers pass a list slice).

    Returns:
        The summary text, or None if the completion call failed.
    '''

    client = OpenAI(api_key=os.environ['MODAL_API_KEY'])
    client.base_url = (
        'https://gperdrizet--vllm-openai-compatible-summarization-serve.modal.run/v1'
    )

    # Default to the first available model served by the endpoint
    model = client.models.list().data[0]
    model_id = model.id

    messages = [
        {
            'role': 'system',
            # Note the trailing space before the concatenation: the original
            # string ran the two sentences together ('...a user.Return...').
            'content': ('Summarize the following interaction between an AI agent and a user. ' +
                f'Return the summary formatted as text, not as JSON: {json.dumps(messages)}')
        }
    ]

    completion_args = {
        'model': model_id,
        'messages': messages,
    }

    try:
        response = client.chat.completions.create(**completion_args)

    except Exception as e:  # pylint: disable=broad-exception-caught
        response = None
        logger.error('Error during Modal API call: %s', e)

    if response is not None:
        summary = response.choices[0].message.content
    else:
        summary = None

    return summary
def step_wait(memory_step: ActionStep, agent: CodeAgent) -> None:
    '''Sleeps between agent steps to avoid hitting API rate limits.

    Blocks for STEP_WAIT seconds after each step.

    Args:
        memory_step: The step that just completed (used for logging only).
        agent: The running CodeAgent (used for logging only).
    '''

    logger.info('Waiting for %d seconds to prevent hitting API rate limits', STEP_WAIT)
    logger.info('Current step is %d', memory_step.step_number)
    logger.info('Current agent has %d steps', len(agent.memory.steps))

    time.sleep(STEP_WAIT)
    # No return: the declared return type is None; the original 'return True'
    # contradicted the annotation and step-callback return values are unused.