Spaces:
Sleeping
Sleeping
| """ | |
| API Testing and Debugging Module | |
| This module provides comprehensive testing and debugging tools for API connections, | |
| specifically designed for Anthropic Claude and Tavily search APIs. | |
| """ | |
import asyncio
import json
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Union, Tuple

import httpx
from pydantic import BaseModel, Field, ValidationError

from ..utils.logging import get_logger
from ..utils.security import ErrorResponse, ErrorType
class APITestResult(BaseModel):
    """Standardized result for API tests."""

    # Outcome and identity of the test.
    success: bool = Field(..., description="Whether the test passed")
    api_name: str = Field(..., description="Name of the API being tested")
    test_name: str = Field(..., description="Name of the specific test")

    # Measurements taken while running the test.
    response_time_ms: float = Field(..., description="Response time in milliseconds")
    status_code: Optional[int] = Field(default=None, description="HTTP status code")

    # Payloads: exactly one of error_message / response_data is normally set by the tester.
    error_message: Optional[str] = Field(default=None, description="Error message if test failed")
    response_data: Optional[Dict[str, Any]] = Field(default=None, description="Response data if successful")

    # Bookkeeping; timestamp is a naive local time taken when the result is built.
    timestamp: datetime = Field(default_factory=datetime.now, description="Test timestamp")
    debug_info: Optional[Dict[str, Any]] = Field(default=None, description="Additional debug information")
class APIErrorType(Enum):
    """Categories of API failure, used to decide whether a request is retried."""

    # Transport-level failures (no HTTP response received).
    NETWORK_ERROR = "network_error"
    # HTTP 401 / 403.
    AUTHENTICATION_ERROR = "authentication_error"
    AUTHORIZATION_ERROR = "authorization_error"
    # HTTP 429.
    RATE_LIMIT_ERROR = "rate_limit_error"
    # Any other 4xx status.
    VALIDATION_ERROR = "validation_error"
    # Request did not complete within the configured timeout.
    TIMEOUT_ERROR = "timeout_error"
    # HTTP 5xx.
    SERVER_ERROR = "server_error"
    # Anything that does not fit the categories above.
    UNKNOWN_ERROR = "unknown_error"
@dataclass
class APIConfig:
    """Configuration for API testing.

    The ``@dataclass`` decorator is what gives this class its keyword/positional
    constructor; without it ``APIConfig(name=..., ...)`` raises ``TypeError``.

    Attributes:
        name: Human-readable API name, used in log messages and the User-Agent.
        base_url: Base URL that request endpoints are resolved against.
        api_key: Secret credential. Excluded from ``repr`` so it is not leaked
            into logs or tracebacks.
        timeout: Request timeout passed to ``httpx.Timeout`` (seconds).
        max_retries: Number of retries after the first attempt.
        retry_delay: Base delay in seconds for exponential backoff between retries.
        rate_limit_per_minute: Client-side request budget per one-minute window.
    """

    name: str
    base_url: str
    api_key: str = field(repr=False)
    timeout: int = 30
    max_retries: int = 3
    retry_delay: float = 1.0
    rate_limit_per_minute: int = 60
class APITester:
    """
    Comprehensive API testing and debugging tool.

    This class provides:
    - Async API connection testing
    - Proper error handling for different failure types
    - Response structure validation
    - Retry logic with exponential backoff
    - Step-by-step debugging capabilities

    Use it as an async context manager so the underlying HTTP client is always
    closed::

        async with APITester(config) as tester:
            results = await tester.run_all_tests()
            tester.print_test_summary(results)

    NOTE(review): every request sends ``Authorization: Bearer <api_key>``
    (see ``_setup_client``); confirm this matches the target API's auth scheme.
    """

    # Error categories that _should_retry allows retrying; all others fail fast.
    _RETRYABLE_ERRORS = frozenset({
        APIErrorType.NETWORK_ERROR,
        APIErrorType.TIMEOUT_ERROR,
        APIErrorType.SERVER_ERROR,
        APIErrorType.RATE_LIMIT_ERROR,
    })

    def __init__(self, config: APIConfig):
        """Initialize the API tester.

        Args:
            config: Connection, retry and client-side rate-limit settings.
        """
        self.config = config
        self.logger = get_logger(f"api_tester.{config.name}")
        self._client: Optional[httpx.AsyncClient] = None
        self._request_count = 0
        self._last_request_time: Optional[datetime] = None
        # Client-side token bucket: rate_limit_per_minute tokens per one-minute window.
        self._rate_limit_tokens = config.rate_limit_per_minute
        self._rate_limit_reset = datetime.now() + timedelta(minutes=1)
        self.logger.info(f"Initialized API tester for {config.name}")

    async def __aenter__(self):
        """Async context manager entry: create the HTTP client."""
        await self._setup_client()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the client and log the outcome.

        Returns None, so any exception raised inside the ``async with`` body
        propagates to the caller.
        """
        await self._cleanup_client()
        if exc_type:
            self.logger.error(f"API tester for {self.config.name} exited with error: {exc_val}")
        else:
            self.logger.info(f"API tester for {self.config.name} completed successfully")

    # =========================================================================
    # CLIENT MANAGEMENT
    # =========================================================================

    async def _setup_client(self) -> None:
        """Create the shared ``httpx.AsyncClient`` if it does not exist yet."""
        if self._client is not None:
            return
        self._client = httpx.AsyncClient(
            base_url=self.config.base_url,
            timeout=httpx.Timeout(self.config.timeout),
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "User-Agent": f"WanderlustAI-APITester/{self.config.name}",
                "Content-Type": "application/json",
            },
        )
        self.logger.info(f"HTTP client setup for {self.config.name}")

    async def _cleanup_client(self) -> None:
        """Close and drop the HTTP client, if one is open."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
            self.logger.info(f"HTTP client cleaned up for {self.config.name}")

    # =========================================================================
    # RATE LIMITING
    # =========================================================================

    async def _check_rate_limit(self) -> None:
        """Wait until a client-side rate-limit token is available, then consume it.

        Tokens refill to ``config.rate_limit_per_minute`` once the current
        one-minute window has elapsed. Exactly one token is consumed per call.
        """
        while True:
            now = datetime.now()
            # Reset tokens if the window has passed.
            if now >= self._rate_limit_reset:
                self._rate_limit_tokens = self.config.rate_limit_per_minute
                self._rate_limit_reset = now + timedelta(minutes=1)
                self.logger.debug(f"Rate limit reset for {self.config.name}")
            if self._rate_limit_tokens > 0:
                break
            # Out of tokens: sleep until the window resets, then re-check.
            # (A loop rather than recursion, so only one token is consumed below.)
            wait_time = (self._rate_limit_reset - now).total_seconds()
            self.logger.warning(f"Rate limit exceeded for {self.config.name}, waiting {wait_time:.1f}s")
            await asyncio.sleep(wait_time)

        self._rate_limit_tokens -= 1
        self.logger.debug(f"Rate limit: {self._rate_limit_tokens} tokens remaining for {self.config.name}")

    # =========================================================================
    # ERROR HANDLING
    # =========================================================================

    @staticmethod
    def _classify_status(status_code: int) -> APIErrorType:
        """Map an HTTP error status code (>= 400) to an APIErrorType."""
        if status_code == 401:
            return APIErrorType.AUTHENTICATION_ERROR
        if status_code == 403:
            return APIErrorType.AUTHORIZATION_ERROR
        if status_code == 429:
            return APIErrorType.RATE_LIMIT_ERROR
        if 500 <= status_code < 600:
            return APIErrorType.SERVER_ERROR
        return APIErrorType.VALIDATION_ERROR

    def _classify_error(self, error: Exception, status_code: Optional[int] = None) -> APIErrorType:
        """Classify API errors for proper handling.

        Args:
            error: The exception raised by httpx (or anything else).
            status_code: Optional explicit status; for ``httpx.HTTPStatusError``
                it defaults to the status of the error's response.
        """
        if isinstance(error, httpx.TimeoutException):
            return APIErrorType.TIMEOUT_ERROR
        # NetworkError covers ConnectError as well as ReadError/WriteError/CloseError.
        if isinstance(error, httpx.NetworkError):
            return APIErrorType.NETWORK_ERROR
        if isinstance(error, httpx.HTTPStatusError):
            if status_code is None:
                status_code = error.response.status_code
            return self._classify_status(status_code)
        return APIErrorType.UNKNOWN_ERROR

    def _should_retry(
        self, error_type: APIErrorType, attempt: int, max_retries: Optional[int] = None
    ) -> bool:
        """Return True if another attempt should be made.

        Args:
            error_type: Classification of the failure that just happened.
            attempt: Zero-based index of the attempt that failed.
            max_retries: Per-call override of ``config.max_retries``.
        """
        limit = self.config.max_retries if max_retries is None else max_retries
        return attempt < limit and error_type in self._RETRYABLE_ERRORS

    def _get_retry_delay(self, attempt: int, error_type: APIErrorType) -> float:
        """Calculate retry delay (seconds) with exponential backoff."""
        delay = self.config.retry_delay * (2 ** attempt)
        if error_type == APIErrorType.RATE_LIMIT_ERROR:
            # Longer delay for rate limit errors.
            delay *= 2
        return delay

    async def _backoff(
        self,
        error_type: APIErrorType,
        attempt: int,
        reason: str,
        max_retries: Optional[int] = None,
    ) -> bool:
        """Sleep before the next attempt if the error is retryable.

        Returns:
            True if the caller should retry, False if it should give up.
        """
        if not self._should_retry(error_type, attempt, max_retries):
            return False
        delay = self._get_retry_delay(attempt, error_type)
        self.logger.info(f"{self.config.name}: Retrying in {delay:.1f}s ({reason})")
        await asyncio.sleep(delay)
        return True

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Milliseconds elapsed since ``start`` (a ``time.perf_counter()`` value)."""
        return (time.perf_counter() - start) * 1000

    # =========================================================================
    # HTTP REQUEST METHODS
    # =========================================================================

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        *,
        max_retries: Optional[int] = None,
    ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], Optional[int], float]:
        """
        Make an HTTP request with proper error handling and retry logic.

        Args:
            method: HTTP method, e.g. ``"GET"``.
            endpoint: Path resolved against ``config.base_url``.
            data: Optional JSON request body.
            params: Optional query parameters.
            headers: Extra headers, merged over the client's default headers.
            max_retries: Per-call override of ``config.max_retries``
                (``0`` disables retries).

        Returns:
            Tuple of (success, data, error_message, status_code, response_time_ms).
            ``data`` is the decoded JSON body (None for an empty body).
            ``response_time_ms`` is the total elapsed time across all attempts,
            including rate-limit waits and backoff sleeps. A client-side timeout
            is reported with status code 408. Request failures are returned in
            the tuple rather than raised.
        """
        retries = self.config.max_retries if max_retries is None else max_retries
        start_time = time.perf_counter()
        request_headers: Dict[str, str] = dict(headers) if headers else {}
        name = self.config.name

        for attempt in range(retries + 1):
            # Every attempt, including retries, consumes a client-side token.
            await self._check_rate_limit()
            try:
                if self._client is None:
                    await self._setup_client()
                self.logger.debug(f"{name}: {method} {endpoint} (attempt {attempt + 1})")
                response = await self._client.request(
                    method=method,
                    url=endpoint,
                    json=data,
                    params=params,
                    headers=request_headers,
                )
            except httpx.TimeoutException:
                self.logger.warning(f"{name}: Timeout on {method} {endpoint} (attempt {attempt + 1})")
                if await self._backoff(APIErrorType.TIMEOUT_ERROR, attempt, "timeout", retries):
                    continue
                return (
                    False, None, f"Request timeout after {attempt + 1} attempts",
                    408, self._elapsed_ms(start_time),
                )
            except httpx.NetworkError as e:
                self.logger.warning(
                    f"{name}: Connection error on {method} {endpoint} (attempt {attempt + 1}): {e}"
                )
                if await self._backoff(APIErrorType.NETWORK_ERROR, attempt, "connection error", retries):
                    continue
                return False, None, f"Connection error: {e}", None, self._elapsed_ms(start_time)
            except Exception as e:
                # Unclassified failures are never retried; report instead of raising.
                self.logger.error(f"{name}: Unexpected error on {method} {endpoint}: {e}")
                return False, None, f"Unexpected error: {e}", None, self._elapsed_ms(start_time)

            response_time = self._elapsed_ms(start_time)
            self._request_count += 1
            self._last_request_time = datetime.now()
            status = response.status_code

            if status < 400:
                payload: Any = None
                # An empty body (e.g. 204 No Content) is a success with no JSON to decode.
                if response.content:
                    try:
                        payload = response.json()
                    except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
                        self.logger.warning(f"{name}: Invalid JSON response: {e}")
                        return False, None, f"Invalid JSON response: {e}", status, response_time
                self.logger.info(f"{name}: {method} {endpoint} - {status} ({response_time:.1f}ms)")
                return True, payload, None, status, response_time

            error_msg = f"HTTP {status}: {response.text}"
            self.logger.warning(f"{name}: {method} {endpoint} - {status} ({response_time:.1f}ms)")
            # Classify from the status code directly (no exception object needed).
            error_type = self._classify_status(status)
            if await self._backoff(error_type, attempt, f"error: {error_type.value}", retries):
                continue
            return False, None, error_msg, status, response_time

        # Only reachable when retries < 0, i.e. no attempt was made at all.
        return False, None, "Max retries exceeded", None, self._elapsed_ms(start_time)

    # =========================================================================
    # API TESTING METHODS
    # =========================================================================

    @staticmethod
    def _as_response_data(data: Any) -> Optional[Dict[str, Any]]:
        """Fit a decoded JSON payload into the dict-typed ``response_data`` field.

        Non-dict payloads (lists, scalars) are wrapped as ``{"data": payload}``
        so building an APITestResult does not fail validation.
        """
        if data is None or isinstance(data, dict):
            return data
        return {"data": data}

    def _debug_info(self) -> Dict[str, Any]:
        """Request-tracking snapshot attached to connection test results."""
        return {
            "request_count": self._request_count,
            "last_request_time": self._last_request_time,
        }

    async def test_connection(self, endpoint: str = "/health") -> APITestResult:
        """Test basic API connection with a GET request.

        Args:
            endpoint: Path to probe (defaults to ``/health``).
        """
        test_name = "connection_test"
        self.logger.info(f"Testing {self.config.name} connection...")

        success, data, error, status_code, response_time = await self._make_request("GET", endpoint)

        if success:
            self.logger.info(f"✅ {self.config.name} connection test passed")
        else:
            self.logger.error(f"❌ {self.config.name} connection test failed: {error}")
        # On success error is None; on failure data is None.
        return APITestResult(
            success=success,
            api_name=self.config.name,
            test_name=test_name,
            response_time_ms=response_time,
            status_code=status_code,
            error_message=error,
            response_data=self._as_response_data(data),
            debug_info=self._debug_info(),
        )

    async def test_authentication(self, endpoint: str = "/auth/test") -> APITestResult:
        """Test API authentication with a GET request to a protected endpoint.

        Args:
            endpoint: Path that requires authentication (defaults to ``/auth/test``).
        """
        test_name = "authentication_test"
        self.logger.info(f"Testing {self.config.name} authentication...")

        success, data, error, status_code, response_time = await self._make_request("GET", endpoint)

        if success:
            self.logger.info(f"✅ {self.config.name} authentication test passed")
            return APITestResult(
                success=True,
                api_name=self.config.name,
                test_name=test_name,
                response_time_ms=response_time,
                status_code=status_code,
                response_data=self._as_response_data(data),
            )

        if status_code == 401:
            self.logger.error(f"❌ {self.config.name} authentication test failed: Invalid API key")
            error = "Invalid API key or authentication failed"
        else:
            self.logger.error(f"❌ {self.config.name} authentication test failed: {error}")
        return APITestResult(
            success=False,
            api_name=self.config.name,
            test_name=test_name,
            response_time_ms=response_time,
            status_code=status_code,
            error_message=error,
        )

    async def test_rate_limits(self, endpoint: str = "/test", num_requests: int = 10) -> APITestResult:
        """Probe server-side rate limiting with rapid sequential requests.

        Requests are sent with retries disabled so a 429 surfaces immediately
        instead of being hidden by backoff. The client-side token bucket still
        applies.

        Args:
            endpoint: Path to hit repeatedly (defaults to ``/test``).
            num_requests: Maximum number of requests to send (defaults to 10).

        Returns:
            Passed if a 429 was seen, or if requests succeeded without one
            (inconclusive). Failed if no request got through for another reason.
        """
        test_name = "rate_limit_test"
        self.logger.info(f"Testing {self.config.name} rate limits...")

        start_time = time.perf_counter()
        success_count = 0
        failure_count = 0
        last_error: Optional[str] = None
        last_status: Optional[int] = None
        rate_limit_hit = False

        for _ in range(num_requests):
            success, _data, error, status_code, _elapsed = await self._make_request(
                "GET", endpoint, max_retries=0
            )
            if success:
                success_count += 1
            elif status_code == 429:
                rate_limit_hit = True
                break
            else:
                failure_count += 1
                last_error, last_status = error, status_code

        total_time = self._elapsed_ms(start_time)
        summary = {
            "rate_limit_detected": rate_limit_hit,
            "successful_requests": success_count,
            "failed_requests": failure_count,
        }

        if rate_limit_hit:
            self.logger.info(f"✅ {self.config.name} rate limit test passed (rate limit detected)")
            return APITestResult(
                success=True,
                api_name=self.config.name,
                test_name=test_name,
                response_time_ms=total_time,
                status_code=429,
                response_data=summary,
            )

        if success_count == 0 and failure_count > 0:
            # Nothing got through and no 429 was seen: the probe itself is broken.
            self.logger.error(f"❌ {self.config.name} rate limit test failed: {last_error}")
            return APITestResult(
                success=False,
                api_name=self.config.name,
                test_name=test_name,
                response_time_ms=total_time,
                status_code=last_status,
                error_message=last_error,
                response_data=summary,
            )

        self.logger.warning(f"⚠️ {self.config.name} rate limit test inconclusive (no rate limit hit)")
        return APITestResult(
            success=True,
            api_name=self.config.name,
            test_name=test_name,
            response_time_ms=total_time,
            response_data=summary,
        )

    # =========================================================================
    # RESPONSE VALIDATION
    # =========================================================================

    def validate_response_structure(self, data: Any, expected_fields: List[str]) -> Tuple[bool, List[str]]:
        """
        Validate API response structure.

        Args:
            data: Response data to validate
            expected_fields: List of expected top-level field names

        Returns:
            Tuple of (is_valid, missing_fields). A non-dict ``data`` yields
            ``(False, ["Response is not a dictionary"])``.
        """
        if not isinstance(data, dict):
            return False, ["Response is not a dictionary"]
        missing_fields = [name for name in expected_fields if name not in data]
        return not missing_fields, missing_fields

    async def test_response_structure(self, endpoint: str, expected_fields: List[str]) -> APITestResult:
        """Test that a GET to ``endpoint`` returns a dict containing ``expected_fields``."""
        test_name = "response_structure_test"
        self.logger.info(f"Testing {self.config.name} response structure for {endpoint}...")

        success, data, error, status_code, response_time = await self._make_request("GET", endpoint)

        if not success:
            self.logger.error(f"❌ {self.config.name} response structure test failed: {error}")
            return APITestResult(
                success=False,
                api_name=self.config.name,
                test_name=test_name,
                response_time_ms=response_time,
                status_code=status_code,
                error_message=error,
            )

        # Validate even an empty {} body: missing fields are then reported precisely.
        is_valid, missing_fields = self.validate_response_structure(data, expected_fields)
        debug = {"expected_fields": expected_fields, "missing_fields": missing_fields}

        if is_valid:
            self.logger.info(f"✅ {self.config.name} response structure test passed")
            return APITestResult(
                success=True,
                api_name=self.config.name,
                test_name=test_name,
                response_time_ms=response_time,
                status_code=status_code,
                response_data=self._as_response_data(data),
                debug_info=debug,
            )

        self.logger.error(
            f"❌ {self.config.name} response structure test failed: missing fields {missing_fields}"
        )
        return APITestResult(
            success=False,
            api_name=self.config.name,
            test_name=test_name,
            response_time_ms=response_time,
            status_code=status_code,
            error_message=f"Missing required fields: {missing_fields}",
            response_data=self._as_response_data(data),
            debug_info=debug,
        )

    # =========================================================================
    # COMPREHENSIVE TESTING
    # =========================================================================

    async def run_all_tests(self) -> List[APITestResult]:
        """Run connection, authentication and rate-limit tests concurrently.

        A test that raises is converted into a failed APITestResult carrying
        that test's name, so the returned list always has one entry per test.
        """
        self.logger.info(f"Running comprehensive tests for {self.config.name}...")

        test_names = ["connection_test", "authentication_test", "rate_limit_test"]
        results = await asyncio.gather(
            self.test_connection(),
            self.test_authentication(),
            self.test_rate_limits(),
            return_exceptions=True,
        )

        final_results: List[APITestResult] = []
        for test_name, result in zip(test_names, results):
            # BaseException also covers a cancelled sub-task returned by gather.
            if isinstance(result, BaseException):
                self.logger.error(f"Test {test_name} failed with exception: {result}")
                final_results.append(APITestResult(
                    success=False,
                    api_name=self.config.name,
                    test_name=test_name,
                    response_time_ms=0,
                    error_message=f"Test failed with exception: {result}",
                ))
            else:
                final_results.append(result)
        return final_results

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================

    def get_stats(self) -> Dict[str, Any]:
        """Get API tester statistics (request counts and rate-limit state)."""
        return {
            "api_name": self.config.name,
            "request_count": self._request_count,
            "last_request_time": self._last_request_time,
            "rate_limit_tokens": self._rate_limit_tokens,
            "rate_limit_reset": self._rate_limit_reset,
        }

    def print_test_summary(self, results: List[APITestResult]) -> None:
        """Print a summary of test results to stdout."""
        print(f"\n📊 API Test Summary for {self.config.name}")
        print("=" * 50)

        total = len(results)
        passed = sum(1 for r in results if r.success)
        print(f"Tests Passed: {passed}/{total}")
        # Guard against ZeroDivisionError when no results were supplied.
        success_rate = (passed / total) * 100 if total else 0.0
        print(f"Success Rate: {success_rate:.1f}%")

        for result in results:
            status = "✅ PASS" if result.success else "❌ FAIL"
            print(f" {status} {result.test_name} ({result.response_time_ms:.1f}ms)")
            if not result.success and result.error_message:
                print(f" Error: {result.error_message}")
        print()