|
|
"""
|
|
|
Test script to verify comprehensive error handling implementation.
|
|
|
|
|
|
This script tests:
|
|
|
1. Backend structured error responses with error codes
|
|
|
2. Correct HTTP status codes (401, 429, 503, 500)
|
|
|
3. Error classification (missing API key, invalid key, rate limit, provider error)
|
|
|
"""
|
|
|
|
|
|
import sys
|
|
|
import os
|
|
|
|
|
|
|
|
|
# Prepend the ``src`` directory that sits next to this file to the import
# path, so the ``core.exceptions`` and ``schemas.error`` imports below resolve.
# This must run before those imports.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
|
|
|
|
|
from core.exceptions import (
|
|
|
classify_ai_error,
|
|
|
RateLimitExceededException,
|
|
|
APIKeyMissingException,
|
|
|
APIKeyInvalidException,
|
|
|
ProviderUnavailableException,
|
|
|
ProviderErrorException
|
|
|
)
|
|
|
from schemas.error import ErrorCode
|
|
|
|
|
|
|
|
|
def _check_classification(number, label, message, provider, expected_cls,
                          expected_code, expected_status, *,
                          check_provider=False):
    """Classify ``Exception(message)`` and assert its type, error code and status.

    Args:
        number: Case number used in the printed heading and failure messages.
        label: Human-readable case name (e.g. ``"rate limit"``).
        message: Text of the raw exception handed to ``classify_ai_error``.
        provider: Provider name passed through to ``classify_ai_error``.
        expected_cls: Exception class the result must be an instance of.
        expected_code: Expected ``ErrorCode`` value on the result.
        expected_status: Expected HTTP status code on the result.
        check_provider: When true, also assert ``result.provider == provider``.

    Returns:
        The classified exception, so callers can make extra checks.

    Raises:
        AssertionError: With a message naming the case and the actual value.
    """
    print(f"\n{number}. Testing {label} error classification:")
    classified = classify_ai_error(Exception(message), provider=provider)

    # Bare asserts produce an empty message, which the standalone runner prints
    # verbatim; prefix every failure with the case so it is identifiable.
    where = f"case {number} ({label}, message={message!r})"
    assert isinstance(classified, expected_cls), (
        f"{where}: expected {expected_cls.__name__}, "
        f"got {type(classified).__name__}"
    )
    assert classified.error_code == expected_code, (
        f"{where}: expected error_code {expected_code!r}, "
        f"got {classified.error_code!r}"
    )
    assert classified.status_code == expected_status, (
        f"{where}: expected status_code {expected_status}, "
        f"got {classified.status_code!r}"
    )
    if check_provider:
        assert classified.provider == provider, (
            f"{where}: expected provider {provider!r}, "
            f"got {classified.provider!r}"
        )

    print(f" [OK] Correctly classified as {expected_cls.__name__}")
    print(f" [OK] Error code: {classified.error_code}")
    print(f" [OK] Status code: {classified.status_code}")
    return classified


def test_error_classification():
    """Test that raw AI-provider errors are classified into the right exceptions.

    Each case wraps a representative provider message in a plain ``Exception``,
    runs it through ``classify_ai_error`` and checks the resulting exception
    class, ``error_code`` and HTTP ``status_code``.
    """
    print("=" * 60)
    print("Testing Error Classification")
    print("=" * 60)

    rate_limited = _check_classification(
        1, "rate limit", "Error: 429 rate limit exceeded", "gemini",
        RateLimitExceededException, ErrorCode.RATE_LIMIT_EXCEEDED, 429,
        check_provider=True,
    )
    print(f" [OK] Detail: {rate_limited.detail}")

    _check_classification(
        2, "API key missing", "API key not found", "openrouter",
        APIKeyMissingException, ErrorCode.API_KEY_MISSING, 503,
    )

    _check_classification(
        3, "API key invalid", "401 unauthorized - invalid api key", "cohere",
        APIKeyInvalidException, ErrorCode.API_KEY_INVALID, 401,
    )

    _check_classification(
        4, "provider unavailable", "503 service unavailable", "gemini",
        ProviderUnavailableException, ErrorCode.PROVIDER_UNAVAILABLE, 503,
    )

    # A message matching none of the specific patterns should fall through to
    # the generic provider error.
    _check_classification(
        5, "generic provider", "Something went wrong with the AI service",
        "gemini", ProviderErrorException, ErrorCode.PROVIDER_ERROR, 500,
    )

    print("\n" + "=" * 60)
    print("[OK] All error classification tests passed!")
    print("=" * 60)
|
|
|
|
|
|
|
|
|
def test_error_response_structure():
    """Test that a directly constructed AI-provider error exposes its structured fields.

    Builds a ``RateLimitExceededException`` for the ``gemini`` provider. It first
    asserts that every structured-response field is present, before any field is
    read, so a missing field fails as an assertion naming the field rather than as
    an ``AttributeError``. It then checks the field values.
    """
    print("\n" + "=" * 60)
    print("Testing Error Response Structure")
    print("=" * 60)

    error = RateLimitExceededException(provider="gemini")

    print("\n1. Testing error attributes:")

    # Check presence before the prints below touch the attributes; otherwise
    # these checks could never fail.
    required = ("error_code", "detail", "source", "provider", "status_code")
    missing = [name for name in required if not hasattr(error, name)]
    assert not missing, (
        f"RateLimitExceededException is missing attributes: {missing}"
    )

    print(f" [OK] error_code: {error.error_code}")
    print(f" [OK] detail: {error.detail}")
    print(f" [OK] source: {error.source}")
    print(f" [OK] provider: {error.provider}")
    print(f" [OK] status_code: {error.status_code}")

    assert error.source == "AI_PROVIDER", (
        f"expected source 'AI_PROVIDER', got {error.source!r}"
    )
    assert error.provider == "gemini", (
        f"expected provider 'gemini', got {error.provider!r}"
    )
    assert error.error_code == ErrorCode.RATE_LIMIT_EXCEEDED, (
        f"expected error_code {ErrorCode.RATE_LIMIT_EXCEEDED!r}, "
        f"got {error.error_code!r}"
    )
    assert error.status_code == 429, (
        f"expected status_code 429, got {error.status_code!r}"
    )

    print("\n" + "=" * 60)
    print("[OK] All error response structure tests passed!")
    print("=" * 60)
|
|
|
|
|
|
|
|
|
def test_error_codes():
    """Test that every expected ``ErrorCode`` constant is defined.

    Checks all names up front and fails with one assertion that lists every
    missing constant, rather than stopping at the first ``AttributeError``.
    The defined values are then printed grouped by category.
    """
    print("\n" + "=" * 60)
    print("Testing Error Code Constants")
    print("=" * 60)

    # Category heading -> constant names expected on ErrorCode (order is kept
    # for the printed report).
    groups = {
        "AI Provider": (
            "RATE_LIMIT_EXCEEDED",
            "API_KEY_MISSING",
            "API_KEY_INVALID",
            "PROVIDER_UNAVAILABLE",
            "PROVIDER_ERROR",
        ),
        "Authentication": ("UNAUTHORIZED", "TOKEN_EXPIRED", "TOKEN_INVALID"),
        "Validation": ("INVALID_INPUT", "MESSAGE_TOO_LONG", "MESSAGE_EMPTY"),
        "Database": ("CONVERSATION_NOT_FOUND", "DATABASE_ERROR"),
        "Internal": ("INTERNAL_ERROR", "UNKNOWN_ERROR"),
    }

    missing = [
        name
        for names in groups.values()
        for name in names
        if not hasattr(ErrorCode, name)
    ]
    assert not missing, f"ErrorCode is missing constants: {missing}"

    for number, (group, names) in enumerate(groups.items(), start=1):
        print(f"\n{number}. {group} error codes:")
        for name in names:
            print(f" [OK] {name}: {getattr(ErrorCode, name)}")

    print("\n" + "=" * 60)
    print("[OK] All error codes are properly defined!")
    print("=" * 60)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Standalone runner: execute every test in order, exit non-zero on failure.
    import traceback

    try:
        test_error_classification()
        test_error_response_structure()
        test_error_codes()

        print("\n" + "=" * 60)
        print("[SUCCESS] ALL TESTS PASSED!")
        print("=" * 60)
        print("\nError handling implementation is working correctly:")
        print("[OK] Backend structured error responses with error codes")
        print("[OK] Correct HTTP status codes (401, 429, 503, 500)")
        print("[OK] Error classification (missing API key, invalid key, rate limit, provider error)")
        print("[OK] Clear error identification and user-friendly messages")
        print("=" * 60)

    except AssertionError as e:
        # A bare ``assert`` carries no message, so str(e) may be empty; fall
        # back to a placeholder and always show the traceback so the failing
        # line is visible.
        print(f"\n[ERROR] Test failed: {str(e) or 'assertion failed (no message)'}")
        traceback.print_exc()
        sys.exit(1)

    except Exception as e:
        print(f"\n[ERROR] Unexpected error: {e}")
        traceback.print_exc()
        sys.exit(1)
|
|
|
|