"""
Test script to verify comprehensive error handling implementation.

This script tests:
1. Backend structured error responses with error codes
2. Correct HTTP status codes (401, 429, 503, 500)
3. Error classification (missing API key, invalid key, rate limit, provider error)

Checks go through the explicit ``_check`` helper rather than bare ``assert``
statements so they still run under ``python -O`` and every failure reports a
readable reason.
"""

import os
import sys
import traceback

# Add backend src to path so the project packages imported below resolve
# when this file is executed directly as a script.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from core.exceptions import (  # noqa: E402  (requires the sys.path tweak above)
    classify_ai_error,
    RateLimitExceededException,
    APIKeyMissingException,
    APIKeyInvalidException,
    ProviderUnavailableException,
    ProviderErrorException,
)
from schemas.error import ErrorCode  # noqa: E402

_SEPARATOR = "=" * 60

# ErrorCode members this script expects to exist, grouped by the heading
# printed for each group in test_error_codes().
_ERROR_CODE_GROUPS = (
    ("AI Provider error codes", (
        "RATE_LIMIT_EXCEEDED",
        "API_KEY_MISSING",
        "API_KEY_INVALID",
        "PROVIDER_UNAVAILABLE",
        "PROVIDER_ERROR",
    )),
    ("Authentication error codes", (
        "UNAUTHORIZED",
        "TOKEN_EXPIRED",
        "TOKEN_INVALID",
    )),
    ("Validation error codes", (
        "INVALID_INPUT",
        "MESSAGE_TOO_LONG",
        "MESSAGE_EMPTY",
    )),
    ("Database error codes", (
        "CONVERSATION_NOT_FOUND",
        "DATABASE_ERROR",
    )),
    ("Internal error codes", (
        "INTERNAL_ERROR",
        "UNKNOWN_ERROR",
    )),
)

# Attributes every structured AI-provider exception must expose.
_REQUIRED_ERROR_ATTRIBUTES = ("error_code", "detail", "source", "provider", "status_code")


def _check(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    A bare ``assert`` is removed when Python runs with ``-O``. That would let
    this script report success without verifying anything. A bare ``assert``
    without a message also leaves the ``[ERROR] Test failed:`` line empty.
    """
    if not condition:
        raise AssertionError(message)


def _expect_classification(raw_message, provider, expected_type, expected_code,
                           expected_status, *, check_provider=False,
                           show_detail=False):
    """Classify ``Exception(raw_message)`` and verify the resulting exception.

    Args:
        raw_message: Text of the synthetic provider error fed to the classifier.
        provider: Provider name passed to ``classify_ai_error``.
        expected_type: Exception class the classifier should return.
        expected_code: Expected ``error_code`` of the classified exception.
        expected_status: Expected HTTP ``status_code`` of the classified exception.
        check_provider: Also require ``classified.provider == provider``.
        show_detail: Also print the classified exception's ``detail``.

    Raises:
        AssertionError: If any expectation does not hold.
    """
    classified = classify_ai_error(Exception(raw_message), provider=provider)

    # Type is checked first so the attribute checks below only run on the
    # expected exception class.
    _check(
        isinstance(classified, expected_type),
        f"{raw_message!r} classified as {type(classified).__name__}, "
        f"expected {expected_type.__name__}",
    )
    _check(
        classified.error_code == expected_code,
        f"{raw_message!r}: error_code {classified.error_code!r}, "
        f"expected {expected_code!r}",
    )
    _check(
        classified.status_code == expected_status,
        f"{raw_message!r}: status_code {classified.status_code!r}, "
        f"expected {expected_status!r}",
    )
    if check_provider:
        _check(
            classified.provider == provider,
            f"{raw_message!r}: provider {classified.provider!r}, "
            f"expected {provider!r}",
        )

    print(f"   [OK] Correctly classified as {expected_type.__name__}")
    print(f"   [OK] Error code: {classified.error_code}")
    print(f"   [OK] Status code: {classified.status_code}")
    if show_detail:
        print(f"   [OK] Detail: {classified.detail}")


def test_error_classification():
    """Test that errors are correctly classified."""
    print(_SEPARATOR)
    print("Testing Error Classification")
    print(_SEPARATOR)

    # Test 1: Rate limit error
    print("\n1. Testing rate limit error classification:")
    _expect_classification(
        "Error: 429 rate limit exceeded", "gemini",
        RateLimitExceededException, ErrorCode.RATE_LIMIT_EXCEEDED, 429,
        check_provider=True, show_detail=True,
    )

    # Test 2: API key missing error
    print("\n2. Testing API key missing error classification:")
    _expect_classification(
        "API key not found", "openrouter",
        APIKeyMissingException, ErrorCode.API_KEY_MISSING, 503,
    )

    # Test 3: API key invalid error
    print("\n3. Testing API key invalid error classification:")
    _expect_classification(
        "401 unauthorized - invalid api key", "cohere",
        APIKeyInvalidException, ErrorCode.API_KEY_INVALID, 401,
    )

    # Test 4: Provider unavailable error
    print("\n4. Testing provider unavailable error classification:")
    _expect_classification(
        "503 service unavailable", "gemini",
        ProviderUnavailableException, ErrorCode.PROVIDER_UNAVAILABLE, 503,
    )

    # Test 5: Generic provider error
    print("\n5. Testing generic provider error classification:")
    _expect_classification(
        "Something went wrong with the AI service", "gemini",
        ProviderErrorException, ErrorCode.PROVIDER_ERROR, 500,
    )

    print("\n" + _SEPARATOR)
    print("[OK] All error classification tests passed!")
    print(_SEPARATOR)


def test_error_response_structure():
    """Test that error responses have correct structure."""
    print("\n" + _SEPARATOR)
    print("Testing Error Response Structure")
    print(_SEPARATOR)

    # Create a sample error
    error = RateLimitExceededException(provider="gemini")

    # Verify the attributes exist *before* reading them, so a missing one is
    # reported as a test failure rather than an unexpected AttributeError.
    missing = [name for name in _REQUIRED_ERROR_ATTRIBUTES if not hasattr(error, name)]
    _check(not missing, f"RateLimitExceededException is missing attributes: {missing}")

    print("\n1. Testing error attributes:")
    for name in _REQUIRED_ERROR_ATTRIBUTES:
        print(f"   [OK] {name}: {getattr(error, name)}")

    # Verify source is correct
    _check(
        error.source == "AI_PROVIDER",
        f"error.source is {error.source!r}, expected 'AI_PROVIDER'",
    )

    print("\n" + _SEPARATOR)
    print("[OK] All error response structure tests passed!")
    print(_SEPARATOR)


def test_error_codes():
    """Test that all error codes are defined."""
    print("\n" + _SEPARATOR)
    print("Testing Error Code Constants")
    print(_SEPARATOR)

    for number, (title, names) in enumerate(_ERROR_CODE_GROUPS, start=1):
        print(f"\n{number}. {title}:")
        for name in names:
            _check(hasattr(ErrorCode, name), f"ErrorCode.{name} is not defined")
            print(f"   [OK] {name}: {getattr(ErrorCode, name)}")

    print("\n" + _SEPARATOR)
    print("[OK] All error codes are properly defined!")
    print(_SEPARATOR)


if __name__ == "__main__":
    try:
        test_error_classification()
        test_error_response_structure()
        test_error_codes()

        print("\n" + _SEPARATOR)
        print("[SUCCESS] ALL TESTS PASSED!")
        print(_SEPARATOR)
        print("\nError handling implementation is working correctly:")
        print("[OK] Backend structured error responses with error codes")
        print("[OK] Correct HTTP status codes (401, 429, 503, 500)")
        print("[OK] Error classification (missing API key, invalid key, rate limit, provider error)")
        print("[OK] Clear error identification and user-friendly messages")
        print(_SEPARATOR)
    except AssertionError as e:
        print(f"\n[ERROR] Test failed: {e}")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report anything unexpected with a traceback.
        print(f"\n[ERROR] Unexpected error: {e}")
        traceback.print_exc()
        sys.exit(1)