File size: 7,548 Bytes
676582c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
Test script to verify the comprehensive error-handling implementation.

This script tests:
1. Backend structured error responses with error codes
2. Correct HTTP status codes (401, 429, 503, 500)
3. Error classification (missing API key, invalid key, rate limit,
   provider unavailable, generic provider error)
"""

import sys
import os

# Add backend src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from core.exceptions import (
    classify_ai_error,
    RateLimitExceededException,
    APIKeyMissingException,
    APIKeyInvalidException,
    ProviderUnavailableException,
    ProviderErrorException
)
from schemas.error import ErrorCode


def _check_classification(message, provider, expected_cls, expected_code,
                          expected_status):
    """Classify ``Exception(message)`` and assert the expected outcome.

    Args:
        message: Text of the raw exception handed to ``classify_ai_error``.
        provider: Provider name passed through to ``classify_ai_error``.
        expected_cls: Exception class the result must be an instance of.
        expected_code: ``ErrorCode`` member expected on ``error_code``.
        expected_status: HTTP status expected on ``status_code``.

    Returns:
        The classified exception, so callers can make extra assertions.

    Raises:
        AssertionError: With a message naming the input and the mismatch.
    """
    classified = classify_ai_error(Exception(message), provider=provider)
    assert isinstance(classified, expected_cls), (
        f"{message!r} classified as {type(classified).__name__}, "
        f"expected {expected_cls.__name__}"
    )
    assert classified.error_code == expected_code, (
        f"{message!r}: error_code {classified.error_code!r}, "
        f"expected {expected_code!r}"
    )
    assert classified.status_code == expected_status, (
        f"{message!r}: status_code {classified.status_code!r}, "
        f"expected {expected_status!r}"
    )
    print(f"   [OK] Correctly classified as {expected_cls.__name__}")
    print(f"   [OK] Error code: {classified.error_code}")
    print(f"   [OK] Status code: {classified.status_code}")
    return classified


def test_error_classification():
    """Test that raw provider errors are classified into the right exceptions.

    Each case feeds a plain ``Exception`` with a representative message to
    ``classify_ai_error`` and checks the resulting exception type,
    ``ErrorCode`` member and HTTP status code.
    """
    print("=" * 60)
    print("Testing Error Classification")
    print("=" * 60)

    # Test 1: Rate limit error (also checks provider propagation and detail)
    print("\n1. Testing rate limit error classification:")
    classified = _check_classification(
        "Error: 429 rate limit exceeded", "gemini",
        RateLimitExceededException, ErrorCode.RATE_LIMIT_EXCEEDED, 429,
    )
    assert classified.provider == "gemini", (
        f"provider is {classified.provider!r}, expected 'gemini'"
    )
    print(f"   [OK] Detail: {classified.detail}")

    # Test 2: API key missing error
    print("\n2. Testing API key missing error classification:")
    _check_classification(
        "API key not found", "openrouter",
        APIKeyMissingException, ErrorCode.API_KEY_MISSING, 503,
    )

    # Test 3: API key invalid error
    print("\n3. Testing API key invalid error classification:")
    _check_classification(
        "401 unauthorized - invalid api key", "cohere",
        APIKeyInvalidException, ErrorCode.API_KEY_INVALID, 401,
    )

    # Test 4: Provider unavailable error
    print("\n4. Testing provider unavailable error classification:")
    _check_classification(
        "503 service unavailable", "gemini",
        ProviderUnavailableException, ErrorCode.PROVIDER_UNAVAILABLE, 503,
    )

    # Test 5: Generic provider error (no recognizable marker in the message)
    print("\n5. Testing generic provider error classification:")
    _check_classification(
        "Something went wrong with the AI service", "gemini",
        ProviderErrorException, ErrorCode.PROVIDER_ERROR, 500,
    )

    print("\n" + "=" * 60)
    print("[OK] All error classification tests passed!")
    print("=" * 60)


def test_error_response_structure():
    """Test that AI-provider exceptions expose the fields of an error response.

    Constructs a ``RateLimitExceededException`` directly and checks that it
    carries ``error_code``, ``detail``, ``source``, ``provider`` and
    ``status_code``, and that ``source`` is ``"AI_PROVIDER"``.
    """
    print("\n" + "=" * 60)
    print("Testing Error Response Structure")
    print("=" * 60)

    # Create a sample error
    error = RateLimitExceededException(provider="gemini")

    # Check presence BEFORE printing: reading a missing attribute in print()
    # would raise AttributeError and the presence assertions would never run.
    required = ("error_code", "detail", "source", "provider", "status_code")
    missing = [name for name in required if not hasattr(error, name)]
    assert not missing, (
        f"RateLimitExceededException lacks attributes: {', '.join(missing)}"
    )

    print("\n1. Testing error attributes:")
    for name in required:
        print(f"   [OK] {name}: {getattr(error, name)}")

    # Verify source is correct
    assert error.source == "AI_PROVIDER", (
        f"source is {error.source!r}, expected 'AI_PROVIDER'"
    )

    print("\n" + "=" * 60)
    print("[OK] All error response structure tests passed!")
    print("=" * 60)


def test_error_codes():
    """Test that all expected ``ErrorCode`` members are defined.

    Every missing member is reported in a single assertion failure, instead of
    the first one surfacing as an AttributeError. The codes are then printed
    grouped by category.
    """
    print("\n" + "=" * 60)
    print("Testing Error Code Constants")
    print("=" * 60)

    # (category title, member names) in display order.
    groups = (
        ("AI Provider", (
            "RATE_LIMIT_EXCEEDED",
            "API_KEY_MISSING",
            "API_KEY_INVALID",
            "PROVIDER_UNAVAILABLE",
            "PROVIDER_ERROR",
        )),
        ("Authentication", ("UNAUTHORIZED", "TOKEN_EXPIRED", "TOKEN_INVALID")),
        ("Validation", ("INVALID_INPUT", "MESSAGE_TOO_LONG", "MESSAGE_EMPTY")),
        ("Database", ("CONVERSATION_NOT_FOUND", "DATABASE_ERROR")),
        ("Internal", ("INTERNAL_ERROR", "UNKNOWN_ERROR")),
    )

    missing = [
        name
        for _, names in groups
        for name in names
        if not hasattr(ErrorCode, name)
    ]
    assert not missing, f"ErrorCode is missing members: {', '.join(missing)}"

    for index, (title, names) in enumerate(groups, start=1):
        print(f"\n{index}. {title} error codes:")
        for name in names:
            print(f"   [OK] {name}: {getattr(ErrorCode, name)}")

    print("\n" + "=" * 60)
    print("[OK] All error codes are properly defined!")
    print("=" * 60)


if __name__ == "__main__":
    # Run every check in order; any failure aborts with exit status 1.
    import traceback

    try:
        test_error_classification()
        test_error_response_structure()
        test_error_codes()

        print("\n" + "=" * 60)
        print("[SUCCESS] ALL TESTS PASSED!")
        print("=" * 60)
        print("\nError handling implementation is working correctly:")
        print("[OK] Backend structured error responses with error codes")
        print("[OK] Correct HTTP status codes (401, 429, 503, 500)")
        print("[OK] Error classification (missing API key, invalid key, rate limit, provider error)")
        print("[OK] Clear error identification and user-friendly messages")
        print("=" * 60)

    except AssertionError as e:
        # An assert without a message stringifies to "", so also print the
        # traceback to show which check failed.
        print(f"\n[ERROR] Test failed: {e}")
        traceback.print_exc()
        sys.exit(1)
    except Exception as e:
        print(f"\n[ERROR] Unexpected error: {e}")
        traceback.print_exc()
        sys.exit(1)