Spaces:
Sleeping
Sleeping
| """ | |
| Base Agent Class for Multi-Agent Travel Planner | |
| This module demonstrates: | |
| - Async/await patterns for API calls | |
| - Object-oriented programming principles | |
| - Proper context management | |
| - Error handling and logging | |
| - Inheritance patterns for specialized agents | |
| """ | |
| import asyncio | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime, timedelta | |
| from typing import Any, Dict, List, Optional, Union, AsyncGenerator | |
| from dataclasses import dataclass | |
| import httpx | |
| from pydantic import BaseModel, Field | |
| from ..utils.logging import get_logger | |
| from ..utils.security import ErrorResponse, ErrorType | |
@dataclass
class AgentConfig:
    """Configuration for AI agents.

    Declared as a dataclass so it can be built with keyword arguments
    (``AgentConfig(name=..., api_base_url=..., api_key=...)``) and so the
    optional tuning knobs below fall back to their defaults.

    Attributes:
        name: Agent name; used in logger names, log messages and the
            User-Agent header.
        api_base_url: Base URL the agent's HTTP client is bound to.
        api_key: Secret sent as a Bearer token. Note that the generated
            ``repr`` of this dataclass includes it, so avoid logging the
            config object itself.
        timeout: Per-request timeout passed to the HTTP client.
        max_retries: Number of retries after the first failed attempt.
        retry_delay: Base delay in seconds for exponential backoff.
        rate_limit_per_minute: Request budget per one-minute window.
    """

    name: str
    api_base_url: str
    api_key: str
    timeout: int = 30
    max_retries: int = 3
    retry_delay: float = 1.0
    rate_limit_per_minute: int = 60
class APIResponse(BaseModel):
    """Standardized API response model."""

    # Outcome flag and HTTP status are mandatory: every response, including
    # synthesized failure responses, must state them explicitly.
    success: bool = Field(..., description="Whether the API call was successful")

    # Payload and error text are both optional and default to None.
    data: Optional[Dict[str, Any]] = Field(default=None, description="Response data")
    error: Optional[str] = Field(default=None, description="Error message if any")

    status_code: int = Field(..., description="HTTP status code")
    response_time_ms: float = Field(..., description="Response time in milliseconds")

    # Filled in at model creation time via datetime.now (naive local time).
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="Response timestamp",
    )
class BaseAgent(ABC):
    """
    Base class for AI agents making HTTP API calls.

    This class demonstrates:
    - Async/await patterns for non-blocking I/O
    - Context management for resource cleanup
    - Error handling and retry logic
    - Logging and monitoring
    - Inheritance patterns for specialized agents

    Subclasses must implement ``search`` and ``get_details``. The intended
    usage is ``async with ConcreteAgent(config) as agent: ...`` so that the
    underlying HTTP client is closed when the block exits.
    """

    def __init__(self, config: AgentConfig):
        """
        Initialize the base agent.

        No HTTP client is created here; it is built lazily by
        ``_setup_client`` (on ``async with`` entry or on the first request).

        Args:
            config: Agent configuration including API details
        """
        self.config = config
        self.logger = get_logger(f"agent.{config.name}")
        self._client: Optional[httpx.AsyncClient] = None
        self._request_count = 0
        self._last_request_time: Optional[datetime] = None

        # Rate limiting: a fixed one-minute window with a per-window budget.
        self._rate_limit_tokens = config.rate_limit_per_minute
        self._rate_limit_reset = datetime.now() + timedelta(minutes=1)

        self.logger.info(f"Initialized {config.name} agent")

    # =========================================================================
    # ASYNC CONTEXT MANAGEMENT
    # =========================================================================

    async def __aenter__(self):
        """Async context manager entry: create the HTTP client and return self."""
        await self._setup_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Async context manager exit with proper cleanup.

        The client is closed first, then the outcome is logged. Returns
        ``None``, so an exception raised inside the ``async with`` body
        propagates to the caller.
        """
        await self._cleanup_client()
        if exc_type:
            self.logger.error(f"Agent {self.config.name} exited with error: {exc_val}")
        else:
            self.logger.info(f"Agent {self.config.name} completed successfully")

    @asynccontextmanager
    async def get_client(self) -> AsyncGenerator[httpx.AsyncClient, None]:
        """
        Context manager for HTTP client.

        Yields the agent's shared client, creating it first if necessary.
        Must be wrapped with ``asynccontextmanager``: a bare async generator
        cannot be used in ``async with``.

        Yields:
            The shared ``httpx.AsyncClient`` instance.
        """
        if self._client is None:
            await self._setup_client()
        # Deliberately no cleanup here: the client is shared across requests
        # and is closed by __aexit__ / _cleanup_client.
        yield self._client

    # =========================================================================
    # CLIENT MANAGEMENT
    # =========================================================================

    async def _setup_client(self):
        """Setup HTTP client with proper configuration (no-op if one exists)."""
        if self._client is not None:
            return

        self._client = httpx.AsyncClient(
            base_url=self.config.api_base_url,
            timeout=httpx.Timeout(self.config.timeout),
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "User-Agent": f"WanderlustAI/{self.config.name}",
                "Content-Type": "application/json",
            },
        )
        self.logger.info(f"HTTP client setup for {self.config.name}")

    async def _cleanup_client(self):
        """Cleanup HTTP client resources (no-op if no client is open)."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
            self.logger.info(f"HTTP client cleaned up for {self.config.name}")

    # =========================================================================
    # RATE LIMITING
    # =========================================================================

    async def _check_rate_limit(self):
        """
        Check and enforce rate limiting.

        Refills the budget when the current one-minute window has elapsed,
        sleeps until the window resets if the budget is exhausted, and then
        consumes exactly one token for the request about to be made.
        """
        # Loop rather than recurse: after sleeping, the state is re-evaluated
        # and only a single token is consumed once one is available.
        while True:
            now = datetime.now()

            # Reset tokens if minute has passed
            if now >= self._rate_limit_reset:
                self._rate_limit_tokens = self.config.rate_limit_per_minute
                self._rate_limit_reset = now + timedelta(minutes=1)
                self.logger.debug(f"Rate limit reset for {self.config.name}")

            if self._rate_limit_tokens > 0:
                break

            wait_time = (self._rate_limit_reset - now).total_seconds()
            self.logger.warning(f"Rate limit exceeded for {self.config.name}, waiting {wait_time:.1f}s")
            await asyncio.sleep(wait_time)

        # Consume a token
        self._rate_limit_tokens -= 1
        self.logger.debug(f"Rate limit: {self._rate_limit_tokens} tokens remaining for {self.config.name}")

    # =========================================================================
    # HTTP REQUEST METHODS
    # =========================================================================

    def _parse_json_body(self, response: httpx.Response) -> Optional[Dict[str, Any]]:
        """
        Decode a successful response body into the dict expected by APIResponse.

        Returns:
            ``None`` for an empty body (e.g. 204 No Content), a body that is
            not valid JSON, or a JSON ``null``; the decoded object when it is
            a dict; otherwise the decoded value wrapped as ``{"result": value}``
            so that list/scalar payloads still fit ``APIResponse.data``.
        """
        if not response.content:
            return None
        try:
            payload = response.json()
        except ValueError:
            # JSON decode errors are ValueError subclasses. The request itself
            # succeeded, so this must not trigger a retry.
            self.logger.debug(f"{self.config.name}: response body is not valid JSON")
            return None
        if payload is None or isinstance(payload, dict):
            return payload
        return {"result": payload}

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None
    ) -> APIResponse:
        """
        Make an HTTP request with proper error handling and retry logic.

        One rate-limit token is consumed per call. The request is attempted up
        to ``max_retries + 1`` times; only raised exceptions (timeouts and
        other errors) trigger a retry, with exponential backoff between
        attempts. Any HTTP response, including 4xx/5xx, is returned
        immediately.

        Args:
            method: HTTP method name, e.g. "GET".
            endpoint: URL or path passed to the client as ``url``.
            data: Optional JSON body.
            params: Optional query parameters.
            headers: Optional per-request headers (copied, not mutated).

        Returns:
            An APIResponse. Failures after the final attempt are reported as
            APIResponse objects (408 for timeouts, 500 for unexpected errors)
            rather than raised. ``response_time_ms`` is measured from the
            start of this call, so it includes rate-limit and backoff waits.
        """
        start_time = datetime.now()

        def elapsed_ms() -> float:
            return (datetime.now() - start_time).total_seconds() * 1000

        def failure(error: str, status_code: int) -> APIResponse:
            return APIResponse(
                success=False,
                error=error,
                status_code=status_code,
                response_time_ms=elapsed_ms(),
            )

        # Rate limiting check
        await self._check_rate_limit()

        # Prepare request
        request_headers: Dict[str, str] = dict(headers) if headers else {}
        max_retries = self.config.max_retries

        # Retry logic
        for attempt in range(max_retries + 1):
            is_last_attempt = attempt == max_retries
            try:
                async with self.get_client() as client:
                    self.logger.debug(f"{self.config.name}: {method} {endpoint} (attempt {attempt + 1})")
                    response = await client.request(
                        method=method,
                        url=endpoint,
                        json=data,
                        params=params,
                        headers=request_headers,
                    )

                response_time = elapsed_ms()

                # Update request tracking
                self._request_count += 1
                self._last_request_time = datetime.now()

                # Create standardized response
                succeeded = response.status_code < 400
                api_response = APIResponse(
                    success=succeeded,
                    data=self._parse_json_body(response) if succeeded else None,
                    error=None if succeeded else response.text,
                    status_code=response.status_code,
                    response_time_ms=response_time,
                )

                if api_response.success:
                    self.logger.info(f"{self.config.name}: {method} {endpoint} - {response.status_code} ({response_time:.1f}ms)")
                else:
                    self.logger.warning(f"{self.config.name}: {method} {endpoint} - {response.status_code} ({response_time:.1f}ms)")
                return api_response

            except httpx.TimeoutException:
                self.logger.warning(f"{self.config.name}: Timeout on {method} {endpoint} (attempt {attempt + 1})")
                if is_last_attempt:
                    return failure(f"Request timeout after {max_retries + 1} attempts", 408)

            except httpx.HTTPStatusError as e:
                # Nothing in this method calls raise_for_status(), so this is
                # not expected to fire; kept as a defensive branch.
                self.logger.warning(f"{self.config.name}: HTTP error {e.response.status_code} on {method} {endpoint}")
                if is_last_attempt:
                    return failure(f"HTTP error: {e.response.status_code}", e.response.status_code)

            except Exception as e:
                self.logger.error(f"{self.config.name}: Unexpected error on {method} {endpoint}: {e}")
                if is_last_attempt:
                    return failure(f"Unexpected error: {str(e)}", 500)

            # Wait before retry
            if not is_last_attempt:
                wait_time = self.config.retry_delay * (2 ** attempt)  # Exponential backoff
                self.logger.debug(f"{self.config.name}: Waiting {wait_time}s before retry")
                await asyncio.sleep(wait_time)

        # Only reachable when the loop body never runs (negative max_retries).
        return failure("Max retries exceeded", 500)

    # =========================================================================
    # CONVENIENCE METHODS
    # =========================================================================

    async def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> APIResponse:
        """Make a GET request."""
        return await self._make_request("GET", endpoint, params=params)

    async def post(self, endpoint: str, data: Optional[Dict[str, Any]] = None) -> APIResponse:
        """Make a POST request."""
        return await self._make_request("POST", endpoint, data=data)

    async def put(self, endpoint: str, data: Optional[Dict[str, Any]] = None) -> APIResponse:
        """Make a PUT request."""
        return await self._make_request("PUT", endpoint, data=data)

    async def delete(self, endpoint: str) -> APIResponse:
        """Make a DELETE request."""
        return await self._make_request("DELETE", endpoint)

    # =========================================================================
    # ABSTRACT METHODS (Must be implemented by subclasses)
    # =========================================================================

    @abstractmethod
    async def search(self, query: Dict[str, Any]) -> APIResponse:
        """
        Search for data using the agent's specific API.

        This method must be implemented by each specialized agent.
        """

    @abstractmethod
    async def get_details(self, item_id: str) -> APIResponse:
        """
        Get detailed information about a specific item.

        This method must be implemented by each specialized agent.
        """

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================

    def get_stats(self) -> Dict[str, Any]:
        """Get agent statistics (request counters and rate-limit state)."""
        return {
            "name": self.config.name,
            "request_count": self._request_count,
            "last_request_time": self._last_request_time,
            "rate_limit_tokens": self._rate_limit_tokens,
            "rate_limit_reset": self._rate_limit_reset,
        }

    async def health_check(self) -> bool:
        """
        Check if the agent is healthy and can make API calls.

        Issues ``GET /health`` and reports whether it succeeded; any raised
        exception is logged and reported as unhealthy.
        """
        try:
            # Try a simple request to check connectivity
            response = await self.get("/health")
            return response.success
        except Exception as e:
            self.logger.error(f"Health check failed for {self.config.name}: {e}")
            return False

    # =========================================================================
    # ERROR HANDLING
    # =========================================================================

    def _handle_error(self, error: Exception, context: str = "") -> ErrorResponse:
        """
        Handle errors and create standardized error responses.

        Args:
            error: The exception being reported.
            context: Optional description of where the error occurred;
                appended to the message when non-empty.

        Returns:
            An ErrorResponse of type API_ERROR with code "AGENT_ERROR".
        """
        error_msg = f"{self.config.name} agent error"
        if context:
            error_msg += f" in {context}"

        self.logger.error(f"{error_msg}: {error}")

        return ErrorResponse(
            error_type=ErrorType.API_ERROR,
            error_code="AGENT_ERROR",
            message=error_msg,
            details=str(error),
            request_id=f"{self.config.name}_{datetime.now().timestamp()}",
        )

    # =========================================================================
    # CLASS METHODS
    # =========================================================================

    @classmethod
    def create_config(
        cls,
        name: str,
        api_base_url: str,
        api_key: str,
        **kwargs
    ) -> AgentConfig:
        """
        Create agent configuration with defaults.

        Args:
            name: Agent name.
            api_base_url: Base URL of the agent's API.
            api_key: API key for the agent's API.
            **kwargs: Forwarded unchanged to AgentConfig (e.g. ``timeout``).
        """
        return AgentConfig(
            name=name,
            api_base_url=api_base_url,
            api_key=api_key,
            **kwargs
        )

    # =========================================================================
    # STRING REPRESENTATION
    # =========================================================================

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(name={self.config.name})"

    def __repr__(self) -> str:
        # NOTE(review): this embeds the config's own repr; if that repr
        # includes api_key, the secret will show up wherever this object is
        # logged. Confirm against AgentConfig.
        return f"{self.__class__.__name__}(config={self.config})"