GAIA_Agent / agent.py
nikhmr1235's picture
add descriptive docstring to Agent.py
e2e0bc0 verified
"""
This module provides a robust wrapper for executing LangChain agents.
The primary class, BasicAgent, enhances LangChain's AgentExecutor with a
resilient retry mechanism. It is designed to handle intermittent API issues,
specifically retriable Google GenAI API errors (HTTP 429 quota/rate-limit
errors and HTTP 503 service-unavailable errors), by implementing a retry
loop with exponential backoff. Before returning a
final answer, it also processes the output through a validation registry.
"""
import time
from typing import Any, List, Optional
from langchain_classic.agents import Agent, AgentExecutor
from langchain_core.tools import BaseTool
from google import genai
from validators import validator_registry
class BasicAgent:
    """
    Wrap a LangChain ``AgentExecutor`` with retry-on-transient-error logic.

    Each question is run through the executor; the executor's ``"output"``
    is then passed to ``validator_registry.process`` and the resulting value
    is returned. Google GenAI ``APIError``s with code 429 or 503 are retried
    (waiting either an API-suggested delay, when the error exposes one, or an
    exponentially growing delay); any other exception propagates unchanged.
    """

    def __init__(
        self,
        agent: Agent,
        tools: List[BaseTool],
        verbose: bool = False,
        handle_parsing_errors: bool = True,
        max_iterations: int = 9,
    ) -> None:
        """
        Store the configuration and build the underlying AgentExecutor.

        Args:
            agent: The LangChain agent to execute.
            tools: Tools made available to the agent.
            verbose: Forwarded to ``AgentExecutor`` as-is.
            handle_parsing_errors: Forwarded to ``AgentExecutor`` as-is.
            max_iterations: Forwarded to ``AgentExecutor`` as-is.
        """
        self.agent: Agent = agent
        self.tools: List[BaseTool] = tools
        self.verbose: bool = verbose
        self.handle_parsing_errors: bool = handle_parsing_errors
        self.max_iterations: int = max_iterations
        self.agent_obj: AgentExecutor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            verbose=self.verbose,
            handle_parsing_errors=self.handle_parsing_errors,
            max_iterations=self.max_iterations,
        )

    def is_retriable(self, e: Exception) -> bool:
        """
        Return True if *e* is a Google GenAI ``APIError`` with code 429 or 503.
        """
        # Adjust this check if your error type is different
        return isinstance(e, genai.errors.APIError) and getattr(e, "code", None) in {429, 503}

    def invoke_with_retry(self, question: str, max_retries: int = 5, initial_delay: float = 10.0) -> str:
        """
        Run *question* through the agent, retrying transient API failures.

        Args:
            question: Passed to the executor as its ``"input"`` and to the
                validator as ``task_description``.
            max_retries: Total number of attempts (including the first one).
            initial_delay: Seconds to wait before the first retry when the
                error exposes no suggested delay; the wait doubles after
                every such fallback retry.

        Returns:
            The executor's ``"output"`` after ``validator_registry.process``.

        Raises:
            RuntimeError: If every attempt fails with a retriable error. The
                last such error is attached as ``__cause__``.
            Exception: Any non-retriable error is re-raised unchanged.
        """
        current_delay = initial_delay
        last_error: Optional[Exception] = None
        for attempt in range(max_retries):
            try:
                result = self.agent_obj.invoke(
                    {"input": question},
                    config={"configurable": {"session_id": "test-session"}},
                )
                # INVOCATION POINT for the validator registry. Kept inside the
                # try so a retriable API error raised during validation is
                # retried as well.
                validated_output = validator_registry.process(
                    task_description=question,
                    answer=result['output'],
                )
                return validated_output
            except Exception as e:
                if not self.is_retriable(e):
                    # If it's not a retriable error, re-raise it
                    raise
                last_error = e
                if attempt == max_retries - 1:
                    # The final attempt failed: give up now instead of
                    # sleeping for a retry that will never happen.
                    break
                suggested = getattr(e, 'retry_delay', None)
                if suggested is not None and hasattr(suggested, 'seconds'):
                    # Use the specific retry_delay provided by the API
                    current_delay = float(suggested.seconds)
                    print(f"Quota error (attempt {attempt+1}/{max_retries}). API suggested retry after {current_delay} seconds.", flush=True)
                    time.sleep(current_delay)
                else:
                    # Fallback to exponential backoff if no specific delay is
                    # provided. Sleep first, then double, so the logged delay
                    # is the one actually applied.
                    print(f"Quota error (attempt {attempt+1}/{max_retries}). Retrying in {current_delay} seconds with exponential backoff.", flush=True)
                    time.sleep(current_delay)
                    current_delay *= 2
        # If all retries fail, raise a RuntimeError chained to the last error
        raise RuntimeError(f"Max retries ({max_retries}) exceeded due to persistent quota errors or other retriable issues.") from last_error

    def __call__(self, question: str) -> str:
        """
        Answer *question* via :meth:`invoke_with_retry` with default retry settings.

        Returns the validated answer produced by ``invoke_with_retry``.
        """
        return self.invoke_with_retry(question)