Spaces:
Running
Running
| """Unit tests for HuggingFace error handling utilities.""" | |
| from unittest.mock import patch | |
| import pytest | |
| from src.utils.hf_error_handler import ( | |
| extract_error_details, | |
| get_fallback_models, | |
| get_user_friendly_error_message, | |
| log_token_info, | |
| should_retry_with_fallback, | |
| validate_hf_token, | |
| ) | |
class TestExtractErrorDetails:
    """Check how extract_error_details turns an exception into a details mapping."""

    def test_extract_403_error(self) -> None:
        """A full 403 message is parsed and flagged as an auth error only."""
        failure = Exception(
            "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden"
        )

        parsed = extract_error_details(failure)

        assert parsed["status_code"] == 403
        assert parsed["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking"
        assert parsed["body"] == "Forbidden"
        assert parsed["error_type"] == "http_403"
        assert parsed["is_auth_error"] is True
        assert parsed["is_model_error"] is False

    def test_extract_422_error(self) -> None:
        """A full 422 message is parsed and flagged as a model error only."""
        failure = Exception(
            "status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity"
        )

        parsed = extract_error_details(failure)

        assert parsed["status_code"] == 422
        assert parsed["model_name"] == "meta-llama/Llama-3.1-70B-Instruct"
        assert parsed["body"] == "Unprocessable Entity"
        assert parsed["error_type"] == "http_422"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is True

    def test_extract_partial_error(self) -> None:
        """A message with only a status code leaves model name and body as None."""
        failure = Exception("status_code: 500")

        parsed = extract_error_details(failure)

        assert parsed["status_code"] == 500
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        assert parsed["error_type"] == "http_500"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is False

    def test_extract_generic_error(self) -> None:
        """A message without any status code is reported as an unknown error."""
        failure = Exception("Something went wrong")

        parsed = extract_error_details(failure)

        assert parsed["status_code"] is None
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        assert parsed["error_type"] == "unknown"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is False
class TestGetUserFriendlyErrorMessage:
    """Check the text produced by get_user_friendly_error_message."""

    def test_403_error_message(self) -> None:
        """A 403 error yields an authentication message naming the model and body."""
        failure = Exception(
            "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden"
        )

        text = get_user_friendly_error_message(failure)

        assert "Authentication Error" in text
        assert "inference-api" in text
        assert "Qwen/Qwen3-Next-80B-A3B-Thinking" in text
        assert "Forbidden" in text

    def test_422_error_message(self) -> None:
        """A 422 error yields a compatibility message naming the model and body."""
        failure = Exception(
            "status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity"
        )

        text = get_user_friendly_error_message(failure)

        assert "Model Compatibility Error" in text
        assert "meta-llama/Llama-3.1-70B-Instruct" in text
        assert "Unprocessable Entity" in text

    def test_generic_error_message(self) -> None:
        """An error without a status code yields a generic API error message."""
        failure = Exception("Something went wrong")

        text = get_user_friendly_error_message(failure)

        assert "API Error" in text
        assert "Something went wrong" in text

    def test_error_message_with_model_name_param(self) -> None:
        """The model_name argument appears in the message when the error omits it."""
        failure = Exception("status_code: 403, body: Forbidden")

        text = get_user_friendly_error_message(failure, model_name="test-model")

        assert "test-model" in text
class TestValidateHfToken:
    """Check the (is_valid, error_message) pair returned by validate_hf_token."""

    def test_valid_token(self) -> None:
        """An hf_-prefixed token of sufficient length is accepted with no error."""
        candidate = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

        ok, reason = validate_hf_token(candidate)

        assert ok is True
        assert reason is None

    def test_none_token(self) -> None:
        """None is rejected with a 'None or empty' explanation."""
        ok, reason = validate_hf_token(None)

        assert ok is False
        assert "None or empty" in reason

    def test_empty_token(self) -> None:
        """The empty string is rejected with a 'None or empty' explanation."""
        ok, reason = validate_hf_token("")

        assert ok is False
        assert "None or empty" in reason

    def test_non_string_token(self) -> None:
        """A non-string value is rejected with a 'not a string' explanation."""
        ok, reason = validate_hf_token(123)  # type: ignore[arg-type]

        assert ok is False
        assert "not a string" in reason

    def test_short_token(self) -> None:
        """A very short token is rejected with a 'too short' explanation."""
        ok, reason = validate_hf_token("hf_123")

        assert ok is False
        assert "too short" in reason

    def test_oauth_token_format(self) -> None:
        """A token that does not start with hf_ (OAuth-style) is still accepted."""
        # OAuth tokens may use a different format than classic hf_ tokens.
        candidate = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"

        ok, reason = validate_hf_token(candidate)

        assert ok is True
        assert reason is None
class TestLogTokenInfo:
    """Tests for log_token_info function.

    Each test replaces the logger used by the module under test with a mock
    so the debug call can be inspected. Without the ``@patch`` decorators the
    ``mock_logger`` argument is resolved as a pytest fixture, and no such
    fixture is defined in this file.
    """

    # NOTE(review): assumes src.utils.hf_error_handler exposes its logger as a
    # module-level attribute named `logger` -- confirm against that module.
    @patch("src.utils.hf_error_handler.logger")
    def test_log_valid_token(self, mock_logger) -> None:
        """Should log token info for valid token."""
        token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

        log_token_info(token, context="test")

        mock_logger.debug.assert_called_once()
        # call_args is an (args, kwargs) pair describing the single debug call.
        args, kwargs = mock_logger.debug.call_args
        assert args[0] == "Token validation"
        assert kwargs["context"] == "test"
        assert kwargs["has_token"] is True
        assert kwargs["is_valid"] is True
        assert kwargs["token_length"] == len(token)
        assert "token_prefix" in kwargs

    # NOTE(review): same assumption about the `logger` attribute as above.
    @patch("src.utils.hf_error_handler.logger")
    def test_log_none_token(self, mock_logger) -> None:
        """Should log None token info."""
        log_token_info(None, context="test")

        mock_logger.debug.assert_called_once()
        args, kwargs = mock_logger.debug.call_args
        assert args[0] == "Token validation"
        assert kwargs["context"] == "test"
        assert kwargs["has_token"] is False
class TestShouldRetryWithFallback:
    """Check which errors should_retry_with_fallback treats as retryable."""

    def test_403_error_should_retry(self) -> None:
        """A 403 error triggers a fallback retry."""
        failure = Exception("status_code: 403, model_name: test-model, body: Forbidden")

        decision = should_retry_with_fallback(failure)

        assert decision is True

    def test_422_error_should_retry(self) -> None:
        """A 422 error triggers a fallback retry."""
        failure = Exception(
            "status_code: 422, model_name: test-model, body: Unprocessable Entity"
        )

        decision = should_retry_with_fallback(failure)

        assert decision is True

    def test_model_specific_error_should_retry(self) -> None:
        """An error that names a model triggers a fallback retry."""
        failure = Exception("status_code: 500, model_name: test-model, body: Error")

        decision = should_retry_with_fallback(failure)

        assert decision is True

    def test_generic_error_should_not_retry(self) -> None:
        """An error with no status code or model name does not trigger a retry."""
        failure = Exception("Something went wrong")

        decision = should_retry_with_fallback(failure)

        assert decision is False
class TestGetFallbackModels:
    """Check the list of model names returned by get_fallback_models."""

    def test_get_fallback_models_default(self) -> None:
        """With no arguments, a non-empty list including the Llama 8B model is returned."""
        candidates = get_fallback_models()

        assert len(candidates) > 0
        assert "meta-llama/Llama-3.1-8B-Instruct" in candidates
        assert isinstance(candidates, list)

    def test_get_fallback_models_excludes_original(self) -> None:
        """The model passed as original_model is left out of the result."""
        failed_model = "meta-llama/Llama-3.1-8B-Instruct"

        candidates = get_fallback_models(original_model=failed_model)

        assert failed_model not in candidates
        assert len(candidates) > 0

    def test_get_fallback_models_with_unknown_original(self) -> None:
        """An original_model outside the fallback list still leaves at least three entries."""
        failed_model = "unknown/model"

        candidates = get_fallback_models(original_model=failed_model)

        # Nothing should be filtered out, so at least 3 fallback models remain.
        assert len(candidates) >= 3  # At least 3 fallback models