Spaces:
Running
Running
File size: 9,353 Bytes
9f69f35 71ca2eb a2f220d 8fa2ce6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
"""Unit tests for HuggingFace error handling utilities."""
from unittest.mock import MagicMock, patch

import pytest

from src.utils.hf_error_handler import (
    extract_error_details,
    get_fallback_models,
    get_user_friendly_error_message,
    log_token_info,
    should_retry_with_fallback,
    validate_hf_token,
)
class TestExtractErrorDetails:
    """Check how ``extract_error_details`` turns an exception message into fields.

    Every case verifies the parsed fields (``status_code``, ``model_name``,
    ``body``) together with the classification fields (``error_type``,
    ``is_auth_error``, ``is_model_error``).
    """

    def test_extract_403_error(self) -> None:
        """A full 403 message is parsed and flagged as an auth error."""
        parsed = extract_error_details(
            Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
        )

        # Fields lifted directly from the message text.
        assert parsed["status_code"] == 403
        assert parsed["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking"
        assert parsed["body"] == "Forbidden"
        # Classification fields.
        assert parsed["error_type"] == "http_403"
        assert parsed["is_auth_error"] is True
        assert parsed["is_model_error"] is False

    def test_extract_422_error(self) -> None:
        """A full 422 message is parsed and flagged as a model error."""
        parsed = extract_error_details(
            Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
        )

        # Fields lifted directly from the message text.
        assert parsed["status_code"] == 422
        assert parsed["model_name"] == "meta-llama/Llama-3.1-70B-Instruct"
        assert parsed["body"] == "Unprocessable Entity"
        # Classification fields.
        assert parsed["error_type"] == "http_422"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is True

    def test_extract_partial_error(self) -> None:
        """A message carrying only a status code leaves the other fields as None."""
        parsed = extract_error_details(Exception("status_code: 500"))

        # Only the status code is present in the message.
        assert parsed["status_code"] == 500
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        # Classification fields.
        assert parsed["error_type"] == "http_500"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is False

    def test_extract_generic_error(self) -> None:
        """A message with no status code yields empty fields and an 'unknown' type."""
        parsed = extract_error_details(Exception("Something went wrong"))

        # Nothing parseable in the message.
        assert parsed["status_code"] is None
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        # Classification fields.
        assert parsed["error_type"] == "unknown"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is False
class TestGetUserFriendlyErrorMessage:
    """Check the text produced by ``get_user_friendly_error_message``.

    Each case asserts that specific fragments appear somewhere in the message.
    """

    def test_403_error_message(self) -> None:
        """A 403 message mentions authentication, 'inference-api', the model and the body."""
        text = get_user_friendly_error_message(
            Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
        )
        for fragment in (
            "Authentication Error",
            "inference-api",
            "Qwen/Qwen3-Next-80B-A3B-Thinking",
            "Forbidden",
        ):
            assert fragment in text, fragment

    def test_422_error_message(self) -> None:
        """A 422 message mentions model compatibility, the model and the body."""
        text = get_user_friendly_error_message(
            Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
        )
        for fragment in (
            "Model Compatibility Error",
            "meta-llama/Llama-3.1-70B-Instruct",
            "Unprocessable Entity",
        ):
            assert fragment in text, fragment

    def test_generic_error_message(self) -> None:
        """An unrecognised error falls back to a generic message echoing the original text."""
        text = get_user_friendly_error_message(Exception("Something went wrong"))
        for fragment in ("API Error", "Something went wrong"):
            assert fragment in text, fragment

    def test_error_message_with_model_name_param(self) -> None:
        """The explicit ``model_name`` argument is used when the error text lacks one."""
        text = get_user_friendly_error_message(
            Exception("status_code: 403, body: Forbidden"), model_name="test-model"
        )
        assert "test-model" in text
class TestValidateHfToken:
    """Tests for validate_hf_token function.

    ``validate_hf_token`` returns an ``(is_valid, error_msg)`` pair, where
    ``error_msg`` is ``None`` for accepted tokens (see the valid-token cases).
    Rejection cases narrow ``error_msg`` to ``str`` before checking its text,
    so an unexpectedly missing message fails as a clear assertion rather than
    a ``TypeError`` from the ``in`` operator.
    """

    def test_valid_token(self) -> None:
        """Should validate a valid token."""
        token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
        is_valid, error_msg = validate_hf_token(token)
        assert is_valid is True
        assert error_msg is None

    def test_none_token(self) -> None:
        """Should reject None token."""
        is_valid, error_msg = validate_hf_token(None)
        assert is_valid is False
        assert error_msg is not None  # narrow Optional[str] before substring check
        assert "None or empty" in error_msg

    def test_empty_token(self) -> None:
        """Should reject empty token."""
        is_valid, error_msg = validate_hf_token("")
        assert is_valid is False
        assert error_msg is not None  # narrow Optional[str] before substring check
        assert "None or empty" in error_msg

    def test_non_string_token(self) -> None:
        """Should reject non-string token."""
        is_valid, error_msg = validate_hf_token(123)  # type: ignore[arg-type]
        assert is_valid is False
        assert error_msg is not None  # narrow Optional[str] before substring check
        assert "not a string" in error_msg

    def test_short_token(self) -> None:
        """Should reject token that's too short."""
        is_valid, error_msg = validate_hf_token("hf_123")
        assert is_valid is False
        assert error_msg is not None  # narrow Optional[str] before substring check
        assert "too short" in error_msg

    def test_oauth_token_format(self) -> None:
        """Should accept OAuth tokens (may not start with hf_)."""
        # OAuth tokens may have different formats; this sample is a JWT-shaped
        # string without the "hf_" prefix.
        token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
        is_valid, error_msg = validate_hf_token(token)
        assert is_valid is True
        assert error_msg is None
class TestLogTokenInfo:
    """Tests for log_token_info function.

    The module-level ``logger`` in ``src.utils.hf_error_handler`` is patched so
    the message and keyword arguments passed to ``logger.debug`` can be
    inspected directly.
    """

    @patch("src.utils.hf_error_handler.logger")
    def test_log_valid_token(self, mock_logger: MagicMock) -> None:
        """Should log token info for valid token without leaking the full token."""
        token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
        log_token_info(token, context="test")
        mock_logger.debug.assert_called_once()
        # A mock ``call`` unpacks into (positional args, keyword args).
        args, kwargs = mock_logger.debug.call_args
        assert args[0] == "Token validation"
        assert kwargs["context"] == "test"
        assert kwargs["has_token"] is True
        assert kwargs["is_valid"] is True
        assert kwargs["token_length"] == len(token)
        assert "token_prefix" in kwargs
        # The logged prefix must never be the whole secret.
        assert kwargs["token_prefix"] != token

    @patch("src.utils.hf_error_handler.logger")
    def test_log_none_token(self, mock_logger: MagicMock) -> None:
        """Should log None token info."""
        log_token_info(None, context="test")
        mock_logger.debug.assert_called_once()
        args, kwargs = mock_logger.debug.call_args
        assert args[0] == "Token validation"
        assert kwargs["context"] == "test"
        assert kwargs["has_token"] is False
class TestShouldRetryWithFallback:
    """Check which errors ``should_retry_with_fallback`` treats as retryable.

    Errors carrying a 403, a 422, or a model name are expected to trigger a
    fallback retry; a bare message with no status or model is not.
    """

    def test_403_error_should_retry(self) -> None:
        """A 403 error is eligible for a fallback retry."""
        forbidden = Exception("status_code: 403, model_name: test-model, body: Forbidden")
        retry = should_retry_with_fallback(forbidden)
        assert retry is True

    def test_422_error_should_retry(self) -> None:
        """A 422 error is eligible for a fallback retry."""
        unprocessable = Exception("status_code: 422, model_name: test-model, body: Unprocessable Entity")
        retry = should_retry_with_fallback(unprocessable)
        assert retry is True

    def test_model_specific_error_should_retry(self) -> None:
        """A 500 error that names a model is eligible for a fallback retry."""
        server_failure = Exception("status_code: 500, model_name: test-model, body: Error")
        retry = should_retry_with_fallback(server_failure)
        assert retry is True

    def test_generic_error_should_not_retry(self) -> None:
        """A bare error with neither status code nor model name is not retried."""
        bare = Exception("Something went wrong")
        retry = should_retry_with_fallback(bare)
        assert retry is False
class TestGetFallbackModels:
    """Tests for get_fallback_models function.

    ``get_fallback_models()`` with no argument gives the default list; passing
    ``original_model`` filters that model out when it is present.
    """

    def test_get_fallback_models_default(self) -> None:
        """Should return default fallback models."""
        fallbacks = get_fallback_models()
        # Check the type first so a wrong return type fails with a clear message.
        assert isinstance(fallbacks, list)
        assert len(fallbacks) > 0
        assert "meta-llama/Llama-3.1-8B-Instruct" in fallbacks

    def test_get_fallback_models_excludes_original(self) -> None:
        """Should exclude original model from fallbacks."""
        original = "meta-llama/Llama-3.1-8B-Instruct"
        fallbacks = get_fallback_models(original_model=original)
        assert original not in fallbacks
        assert len(fallbacks) > 0

    def test_get_fallback_models_with_unknown_original(self) -> None:
        """Should return all fallbacks if original is not in list."""
        original = "unknown/model"
        fallbacks = get_fallback_models(original_model=original)
        # An unknown original filters nothing out, so the result must contain
        # exactly the default models (compared order-insensitively), not just
        # a list of similar size.
        assert sorted(fallbacks) == sorted(get_fallback_models())
        assert len(fallbacks) >= 3  # At least 3 fallback models
|