# wanderlust.ai / src/wanderlust_ai/core/base_agent.py
# Uploaded by BlakeL ("Upload 115 files"), commit 3f9f85b (verified).
"""
Base Agent Class for Multi-Agent Travel Planner
This module demonstrates:
- Async/await patterns for API calls
- Object-oriented programming principles
- Proper context management
- Error handling and logging
- Inheritance patterns for specialized agents
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union, AsyncGenerator

import httpx
from pydantic import BaseModel, Field

from ..utils.logging import get_logger
from ..utils.security import ErrorResponse, ErrorType
@dataclass
class AgentConfig:
    """Configuration for AI agents.

    Attributes:
        name: Agent name; used in the logger name, the ``User-Agent`` header
            and generated error request ids.
        api_base_url: Base URL that request endpoints are resolved against.
        api_key: Credential sent as an ``Authorization: Bearer`` token.
            Excluded from ``repr()`` so it cannot leak into logs, tracebacks
            or any object repr that embeds this config.
        timeout: HTTP client timeout, in seconds.
        max_retries: Number of retries after a failed attempt.
        retry_delay: Base delay in seconds for exponential backoff between
            retries (``retry_delay * 2**attempt``).
        rate_limit_per_minute: Maximum number of requests allowed per
            one-minute window.
    """
    name: str
    api_base_url: str
    # repr=False: the generated __repr__ would otherwise print the secret.
    api_key: str = field(repr=False)
    timeout: int = 30
    max_retries: int = 3
    retry_delay: float = 1.0
    rate_limit_per_minute: int = 60
class APIResponse(BaseModel):
    """
    Uniform envelope for the outcome of an agent HTTP call.

    Carries the success flag, either the decoded payload or the error text,
    the HTTP status code, the elapsed time, and when the envelope was created.
    ``success``, ``status_code`` and ``response_time_ms`` are required; the
    remaining fields have defaults.
    """

    # --- outcome -----------------------------------------------------------
    success: bool = Field(default=..., description="Whether the API call was successful")
    data: Union[Dict[str, Any], None] = Field(default=None, description="Response data")
    error: Union[str, None] = Field(default=None, description="Error message if any")

    # --- transport metadata ------------------------------------------------
    status_code: int = Field(default=..., description="HTTP status code")
    response_time_ms: float = Field(default=..., description="Response time in milliseconds")
    # Naive local time, taken when the model instance is created.
    timestamp: datetime = Field(default_factory=datetime.now, description="Response timestamp")
class BaseAgent(ABC):
    """
    Base class for AI agents making HTTP API calls.

    This class demonstrates:
    - Async/await patterns for non-blocking I/O
    - Context management for resource cleanup
    - Error handling and retry logic
    - Logging and monitoring
    - Inheritance patterns for specialized agents

    Subclasses implement :meth:`search` and :meth:`get_details`, typically on
    top of the :meth:`get` / :meth:`post` / :meth:`put` / :meth:`delete`
    helpers. Prefer ``async with agent: ...`` so the HTTP client is closed on
    exit. A client created lazily by :meth:`get_client` outside such a block
    is only released by ``_cleanup_client``.
    """

    def __init__(self, config: AgentConfig):
        """
        Initialize the base agent.

        Args:
            config: Agent configuration including API details
        """
        self.config = config
        self.logger = get_logger(f"agent.{config.name}")
        self._client: Optional[httpx.AsyncClient] = None
        self._request_count = 0
        self._last_request_time: Optional[datetime] = None

        # Rate limiting: a fixed one-minute window whose token budget is
        # refilled to ``rate_limit_per_minute`` when the window expires.
        self._rate_limit_tokens = config.rate_limit_per_minute
        self._rate_limit_reset = datetime.now() + timedelta(minutes=1)

        self.logger.info(f"Initialized {config.name} agent")

    # =========================================================================
    # ASYNC CONTEXT MANAGEMENT
    # =========================================================================

    async def __aenter__(self):
        """Async context manager entry: create the HTTP client."""
        await self._setup_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Async context manager exit: close the HTTP client and log the outcome.

        Returns None, so an exception raised inside the block propagates.
        """
        await self._cleanup_client()

        if exc_type:
            self.logger.error(f"Agent {self.config.name} exited with error: {exc_val}")
        else:
            self.logger.info(f"Agent {self.config.name} completed successfully")

    @asynccontextmanager
    async def get_client(self) -> AsyncGenerator[httpx.AsyncClient, None]:
        """
        Yield the shared HTTP client, creating it on first use.

        Leaving this context does NOT close the client. The client is reused
        across requests and is released by ``__aexit__`` / ``_cleanup_client``.
        """
        if not self._client:
            await self._setup_client()
        yield self._client

    # =========================================================================
    # CLIENT MANAGEMENT
    # =========================================================================

    async def _setup_client(self):
        """Create the HTTP client (base URL, timeout, auth headers) if absent."""
        if self._client:
            return

        self._client = httpx.AsyncClient(
            base_url=self.config.api_base_url,
            timeout=httpx.Timeout(self.config.timeout),
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "User-Agent": f"WanderlustAI/{self.config.name}",
                "Content-Type": "application/json"
            }
        )

        self.logger.info(f"HTTP client setup for {self.config.name}")

    async def _cleanup_client(self):
        """Close the HTTP client, if any, and drop the reference."""
        if self._client:
            await self._client.aclose()
            self._client = None
            self.logger.info(f"HTTP client cleaned up for {self.config.name}")

    # =========================================================================
    # RATE LIMITING
    # =========================================================================

    async def _check_rate_limit(self):
        """
        Block until a rate-limit token is available, then consume exactly one.

        The token budget is refilled when the one-minute window expires. When
        the budget is exhausted, this sleeps until the window resets and checks
        again. A loop is used rather than recursion, so a throttled call still
        consumes only one token.
        """
        while True:
            now = datetime.now()

            # Reset tokens if minute has passed
            if now >= self._rate_limit_reset:
                self._rate_limit_tokens = self.config.rate_limit_per_minute
                self._rate_limit_reset = now + timedelta(minutes=1)
                self.logger.debug(f"Rate limit reset for {self.config.name}")

            if self._rate_limit_tokens > 0:
                break

            wait_time = (self._rate_limit_reset - now).total_seconds()
            self.logger.warning(f"Rate limit exceeded for {self.config.name}, waiting {wait_time:.1f}s")
            await asyncio.sleep(wait_time)

        # No await between the availability check above and this decrement, so
        # concurrent coroutines on the same event loop cannot interleave here.
        self._rate_limit_tokens -= 1
        self.logger.debug(f"Rate limit: {self._rate_limit_tokens} tokens remaining for {self.config.name}")

    # =========================================================================
    # HTTP REQUEST METHODS
    # =========================================================================

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None
    ) -> APIResponse:
        """
        Make an HTTP request with rate limiting, retries and error handling.

        Only failures to obtain a response are retried: timeouts, connection
        errors, and any other exception raised by the client. Retries use an
        exponential backoff of ``retry_delay * 2**attempt`` seconds. Any HTTP
        response that arrives, including 4xx/5xx, is returned immediately.

        Args:
            method: HTTP method, e.g. ``"GET"``.
            endpoint: URL passed to the client, resolved against
                ``config.api_base_url``.
            data: Optional JSON request body.
            params: Optional query parameters.
            headers: Optional per-request headers, merged over the client's
                default headers.

        Returns:
            An :class:`APIResponse`. Request failures are reported in it rather
            than raised. ``response_time_ms`` is measured from the start of this
            call, so it includes rate-limit waits and retry backoff.
        """
        start_time = datetime.now()

        # Prepare request
        request_headers = dict(headers) if headers else {}

        # A negative max_retries would otherwise skip the loop and never send
        # the request; always make at least one attempt.
        last_attempt = max(self.config.max_retries, 0)

        for attempt in range(last_attempt + 1):
            # Each attempt is a real API call, so each one consumes a token.
            await self._check_rate_limit()

            # Keep the try body to the network call only. Decoding the body
            # inside it would let a JSON error trigger a retry and re-send
            # non-idempotent requests (POST/PUT/DELETE).
            try:
                async with self.get_client() as client:
                    self.logger.debug(f"{self.config.name}: {method} {endpoint} (attempt {attempt + 1})")
                    response = await client.request(
                        method=method,
                        url=endpoint,
                        json=data,
                        params=params,
                        headers=request_headers
                    )
            except httpx.TimeoutException:
                self.logger.warning(f"{self.config.name}: Timeout on {method} {endpoint} (attempt {attempt + 1})")
                if attempt == last_attempt:
                    return APIResponse(
                        success=False,
                        error=f"Request timeout after {last_attempt + 1} attempts",
                        status_code=408,
                        response_time_ms=(datetime.now() - start_time).total_seconds() * 1000
                    )
            except httpx.HTTPStatusError as e:
                # httpx raises this only from Response.raise_for_status(), which
                # this class never calls. It is reachable only if a subclass
                # configures the client to do so. Plain non-2xx responses are
                # handled in _build_api_response.
                self.logger.warning(f"{self.config.name}: HTTP error {e.response.status_code} on {method} {endpoint}")
                if attempt == last_attempt:
                    return APIResponse(
                        success=False,
                        error=f"HTTP error: {e.response.status_code}",
                        status_code=e.response.status_code,
                        response_time_ms=(datetime.now() - start_time).total_seconds() * 1000
                    )
            except Exception as e:
                self.logger.error(f"{self.config.name}: Unexpected error on {method} {endpoint}: {e}")
                if attempt == last_attempt:
                    return APIResponse(
                        success=False,
                        error=f"Unexpected error: {str(e)}",
                        status_code=500,
                        response_time_ms=(datetime.now() - start_time).total_seconds() * 1000
                    )
            else:
                return self._build_api_response(method, endpoint, response, start_time)

            # Wait before retry
            if attempt < last_attempt:
                wait_time = self.config.retry_delay * (2 ** attempt)  # Exponential backoff
                self.logger.debug(f"{self.config.name}: Waiting {wait_time}s before retry")
                await asyncio.sleep(wait_time)

        # Unreachable: the final attempt always returns above. Kept as a
        # defensive fallback.
        return APIResponse(
            success=False,
            error="Max retries exceeded",
            status_code=500,
            response_time_ms=(datetime.now() - start_time).total_seconds() * 1000
        )

    def _build_api_response(
        self,
        method: str,
        endpoint: str,
        response: httpx.Response,
        start_time: datetime
    ) -> APIResponse:
        """
        Convert a received httpx response into an :class:`APIResponse` and log it.

        Also updates the request counter and the last-request timestamp.

        Status codes >= 400 are failures; the raw body text becomes ``error``.
        For status codes < 400, the outcome depends on the body:

        - empty body (e.g. 204 No Content): success with ``data=None``;
        - a JSON object: success with ``data`` set to it;
        - anything else (invalid JSON, or a JSON array/scalar): failure with
          an explanatory ``error`` and the original status code. This is
          reported rather than raised.
        """
        # Calculate response time
        response_time = (datetime.now() - start_time).total_seconds() * 1000

        # Update request tracking
        self._request_count += 1
        self._last_request_time = datetime.now()

        payload: Optional[Dict[str, Any]] = None
        error: Optional[str] = None
        if response.status_code >= 400:
            error = response.text
        elif response.content:
            try:
                decoded = response.json()
            except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
                error = "Response body is not valid JSON"
            else:
                if isinstance(decoded, dict):
                    payload = decoded
                else:
                    error = f"Expected a JSON object, got {type(decoded).__name__}"

        # Create standardized response
        api_response = APIResponse(
            success=response.status_code < 400 and error is None,
            data=payload,
            error=error,
            status_code=response.status_code,
            response_time_ms=response_time
        )

        if api_response.success:
            self.logger.info(f"{self.config.name}: {method} {endpoint} - {response.status_code} ({response_time:.1f}ms)")
        else:
            self.logger.warning(f"{self.config.name}: {method} {endpoint} - {response.status_code} ({response_time:.1f}ms)")

        return api_response

    # =========================================================================
    # CONVENIENCE METHODS
    # =========================================================================

    async def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> APIResponse:
        """Make a GET request."""
        return await self._make_request("GET", endpoint, params=params)

    async def post(self, endpoint: str, data: Optional[Dict[str, Any]] = None) -> APIResponse:
        """Make a POST request."""
        return await self._make_request("POST", endpoint, data=data)

    async def put(self, endpoint: str, data: Optional[Dict[str, Any]] = None) -> APIResponse:
        """Make a PUT request."""
        return await self._make_request("PUT", endpoint, data=data)

    async def delete(self, endpoint: str) -> APIResponse:
        """Make a DELETE request."""
        return await self._make_request("DELETE", endpoint)

    # =========================================================================
    # ABSTRACT METHODS (Must be implemented by subclasses)
    # =========================================================================

    @abstractmethod
    async def search(self, query: Dict[str, Any]) -> APIResponse:
        """
        Search for data using the agent's specific API.

        This method must be implemented by each specialized agent.
        """
        pass

    @abstractmethod
    async def get_details(self, item_id: str) -> APIResponse:
        """
        Get detailed information about a specific item.

        This method must be implemented by each specialized agent.
        """
        pass

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================

    def get_stats(self) -> Dict[str, Any]:
        """
        Return request counters and rate-limit state.

        Timestamps are returned as ``datetime`` objects (or None).
        """
        return {
            "name": self.config.name,
            "request_count": self._request_count,
            "last_request_time": self._last_request_time,
            "rate_limit_tokens": self._rate_limit_tokens,
            "rate_limit_reset": self._rate_limit_reset
        }

    async def health_check(self) -> bool:
        """
        Return whether a GET to ``/health`` succeeds.

        NOTE(review): this assumes the remote API exposes ``/health``; confirm
        for each concrete agent. Any exception is logged and reported as False.
        """
        try:
            # Try a simple request to check connectivity
            response = await self.get("/health")
            return response.success
        except Exception as e:
            self.logger.error(f"Health check failed for {self.config.name}: {e}")
            return False

    # =========================================================================
    # ERROR HANDLING
    # =========================================================================

    def _handle_error(self, error: Exception, context: str = "") -> ErrorResponse:
        """
        Log ``error`` and wrap it in an ``ErrorResponse``.

        The error type is ``ErrorType.API_ERROR`` and the error code is
        ``"AGENT_ERROR"``. The request id is the agent name followed by the
        current POSIX timestamp.
        """
        error_msg = f"{self.config.name} agent error"
        if context:
            error_msg += f" in {context}"

        self.logger.error(f"{error_msg}: {error}")

        return ErrorResponse(
            error_type=ErrorType.API_ERROR,
            error_code="AGENT_ERROR",
            message=error_msg,
            details=str(error),
            request_id=f"{self.config.name}_{datetime.now().timestamp()}"
        )

    # =========================================================================
    # CLASS METHODS
    # =========================================================================

    @classmethod
    def create_config(
        cls,
        name: str,
        api_base_url: str,
        api_key: str,
        **kwargs
    ) -> AgentConfig:
        """
        Create an :class:`AgentConfig`.

        Extra keyword arguments (e.g. ``timeout``, ``max_retries``) are
        forwarded to the config; omitted fields keep their defaults.
        """
        return AgentConfig(
            name=name,
            api_base_url=api_base_url,
            api_key=api_key,
            **kwargs
        )

    # =========================================================================
    # STRING REPRESENTATION
    # =========================================================================

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(name={self.config.name})"

    def __repr__(self) -> str:
        # NOTE(review): embeds repr(self.config). Make sure AgentConfig's repr
        # does not expose the API key.
        return f"{self.__class__.__name__}(config={self.config})"