File size: 9,353 Bytes
9f69f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ca2eb
a2f220d
8fa2ce6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""Unit tests for HuggingFace error handling utilities."""

from unittest.mock import patch

import pytest

from src.utils.hf_error_handler import (
    extract_error_details,
    get_fallback_models,
    get_user_friendly_error_message,
    log_token_info,
    should_retry_with_fallback,
    validate_hf_token,
)


class TestExtractErrorDetails:
    """Unit tests for ``extract_error_details``.

    Each test feeds a plain ``Exception`` whose message mimics the
    ``status_code: ..., model_name: ..., body: ...`` format, then checks the
    parsed fields and the two boolean classification flags.
    """

    @staticmethod
    def _assert_flags(details: dict, *, auth: bool, model: bool) -> None:
        """Assert both classification flags by identity (``is True``/``is False``)."""
        assert details["is_auth_error"] is auth
        assert details["is_model_error"] is model

    def test_extract_403_error(self) -> None:
        """A fully populated 403 message is parsed and flagged as an auth error."""
        parsed = extract_error_details(
            Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
        )

        assert parsed["status_code"] == 403
        assert parsed["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking"
        assert parsed["body"] == "Forbidden"
        assert parsed["error_type"] == "http_403"
        self._assert_flags(parsed, auth=True, model=False)

    def test_extract_422_error(self) -> None:
        """A fully populated 422 message is parsed and flagged as a model error."""
        parsed = extract_error_details(
            Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
        )

        assert parsed["status_code"] == 422
        assert parsed["model_name"] == "meta-llama/Llama-3.1-70B-Instruct"
        assert parsed["body"] == "Unprocessable Entity"
        assert parsed["error_type"] == "http_422"
        self._assert_flags(parsed, auth=False, model=True)

    def test_extract_partial_error(self) -> None:
        """Only a status code is present, so the missing fields come back as None."""
        parsed = extract_error_details(Exception("status_code: 500"))

        assert parsed["status_code"] == 500
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        assert parsed["error_type"] == "http_500"
        self._assert_flags(parsed, auth=False, model=False)

    def test_extract_generic_error(self) -> None:
        """A message with no recognisable fields yields the ``unknown`` error type."""
        parsed = extract_error_details(Exception("Something went wrong"))

        assert parsed["status_code"] is None
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        assert parsed["error_type"] == "unknown"
        self._assert_flags(parsed, auth=False, model=False)


class TestGetUserFriendlyErrorMessage:
    """Unit tests for ``get_user_friendly_error_message``.

    Each test checks that the rendered message contains a fixed set of
    substrings (headline, model name, response body, ...).
    """

    def test_403_error_message(self) -> None:
        """A 403 renders an authentication message naming the model and body."""
        rendered = get_user_friendly_error_message(
            Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
        )

        for fragment in (
            "Authentication Error",
            "inference-api",
            "Qwen/Qwen3-Next-80B-A3B-Thinking",
            "Forbidden",
        ):
            assert fragment in rendered

    def test_422_error_message(self) -> None:
        """A 422 renders a compatibility message naming the model and body."""
        rendered = get_user_friendly_error_message(
            Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
        )

        for fragment in (
            "Model Compatibility Error",
            "meta-llama/Llama-3.1-70B-Instruct",
            "Unprocessable Entity",
        ):
            assert fragment in rendered

    def test_generic_error_message(self) -> None:
        """An unparseable error falls back to a generic message echoing the text."""
        rendered = get_user_friendly_error_message(Exception("Something went wrong"))

        for fragment in ("API Error", "Something went wrong"):
            assert fragment in rendered

    def test_error_message_with_model_name_param(self) -> None:
        """The explicit ``model_name`` argument appears when the error text lacks one."""
        rendered = get_user_friendly_error_message(
            Exception("status_code: 403, body: Forbidden"),
            model_name="test-model",
        )

        assert "test-model" in rendered


class TestValidateHfToken:
    """Tests for validate_hf_token function.

    ``validate_hf_token`` returns an ``(is_valid, error_msg)`` pair. Accepted
    tokens yield ``(True, None)``. Rejected tokens are expected to carry a
    string reason, so the rejection tests first assert the message is not
    None before substring-matching it.
    """

    def test_valid_token(self) -> None:
        """Should validate a valid token."""
        token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
        is_valid, error_msg = validate_hf_token(token)

        assert is_valid is True
        assert error_msg is None

    def test_none_token(self) -> None:
        """Should reject None token."""
        is_valid, error_msg = validate_hf_token(None)

        assert is_valid is False
        # Narrow the Optional so a missing reason fails as an assertion
        # rather than a TypeError from the ``in`` check below.
        assert error_msg is not None
        assert "None or empty" in error_msg

    def test_empty_token(self) -> None:
        """Should reject empty token."""
        is_valid, error_msg = validate_hf_token("")

        assert is_valid is False
        assert error_msg is not None  # narrow Optional before substring check
        assert "None or empty" in error_msg

    def test_non_string_token(self) -> None:
        """Should reject non-string token."""
        is_valid, error_msg = validate_hf_token(123)  # type: ignore[arg-type]

        assert is_valid is False
        assert error_msg is not None  # narrow Optional before substring check
        assert "not a string" in error_msg

    def test_short_token(self) -> None:
        """Should reject token that's too short."""
        is_valid, error_msg = validate_hf_token("hf_123")

        assert is_valid is False
        assert error_msg is not None  # narrow Optional before substring check
        assert "too short" in error_msg

    def test_oauth_token_format(self) -> None:
        """Should accept OAuth tokens (may not start with hf_)."""
        # OAuth tokens may have different formats
        token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
        is_valid, error_msg = validate_hf_token(token)

        assert is_valid is True
        assert error_msg is None


class TestLogTokenInfo:
    """Tests for log_token_info function.

    The module-level ``logger`` in ``src.utils.hf_error_handler`` is patched
    with a mock so the exact ``debug`` call (message plus keyword fields) can
    be inspected.
    """

    @patch("src.utils.hf_error_handler.logger")
    def test_log_valid_token(self, mock_logger) -> None:
        """Should log token metadata for a valid token without logging the token itself."""
        token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
        log_token_info(token, context="test")

        mock_logger.debug.assert_called_once()
        call_args = mock_logger.debug.call_args
        assert call_args[0][0] == "Token validation"
        assert call_args[1]["context"] == "test"
        assert call_args[1]["has_token"] is True
        assert call_args[1]["is_valid"] is True
        assert call_args[1]["token_length"] == len(token)
        assert "token_prefix" in call_args[1]
        # The raw secret must never reach the log. Only derived metadata
        # (length, prefix) is acceptable, so check every positional and
        # keyword value passed to the logger.
        logged_values = [*call_args[0], *call_args[1].values()]
        assert all(token not in str(value) for value in logged_values)

    @patch("src.utils.hf_error_handler.logger")
    def test_log_none_token(self, mock_logger) -> None:
        """Should log None token info."""
        log_token_info(None, context="test")

        mock_logger.debug.assert_called_once()
        call_args = mock_logger.debug.call_args
        assert call_args[0][0] == "Token validation"
        assert call_args[1]["context"] == "test"
        assert call_args[1]["has_token"] is False
        assert call_args[1]["has_token"] is False


class TestShouldRetryWithFallback:
    """Unit tests for ``should_retry_with_fallback``.

    Errors that carry model information (403, 422, or any status with a
    ``model_name``) are retried; bare errors are not.
    """

    def test_403_error_should_retry(self) -> None:
        """A 403 naming a model is eligible for a fallback retry."""
        assert should_retry_with_fallback(
            Exception("status_code: 403, model_name: test-model, body: Forbidden")
        ) is True

    def test_422_error_should_retry(self) -> None:
        """A 422 naming a model is eligible for a fallback retry."""
        assert should_retry_with_fallback(
            Exception("status_code: 422, model_name: test-model, body: Unprocessable Entity")
        ) is True

    def test_model_specific_error_should_retry(self) -> None:
        """Any status that names a model (here 500) is eligible for a retry."""
        assert should_retry_with_fallback(
            Exception("status_code: 500, model_name: test-model, body: Error")
        ) is True

    def test_generic_error_should_not_retry(self) -> None:
        """An error with no model information is not retried."""
        assert should_retry_with_fallback(Exception("Something went wrong")) is False


class TestGetFallbackModels:
    """Tests for get_fallback_models function."""

    def test_get_fallback_models_default(self) -> None:
        """Without an original model, the default fallback list is returned."""
        models = get_fallback_models()

        # Non-empty, contains the expected default, and is a concrete list.
        assert len(models) > 0
        assert "meta-llama/Llama-3.1-8B-Instruct" in models
        assert isinstance(models, list)

    def test_get_fallback_models_excludes_original(self) -> None:
        """The original model is filtered out of the returned fallbacks."""
        original_model_id = "meta-llama/Llama-3.1-8B-Instruct"
        candidates = get_fallback_models(original_model=original_model_id)

        # Filtering must drop only the original, leaving other candidates.
        assert original_model_id not in candidates
        assert len(candidates) > 0

    def test_get_fallback_models_with_unknown_original(self) -> None:
        """Should return all fallbacks if original is not in list."""
        original = "unknown/model"
        fallbacks = get_fallback_models(original_model=original)
        
        # Should still have all fallbacks since original is not in the list
        assert len(fallbacks) >= 3  # At least 3 fallback models