DeepCritical / tests /unit /utils /test_hf_error_handler.py
Joseph Pollack
attempts fix 403 and settings
8fa2ce6 unverified
raw
history blame
9.35 kB
"""Unit tests for HuggingFace error handling utilities."""
from unittest.mock import patch
import pytest
from src.utils.hf_error_handler import (
extract_error_details,
get_fallback_models,
get_user_friendly_error_message,
log_token_info,
should_retry_with_fallback,
validate_hf_token,
)
class TestExtractErrorDetails:
    """Exercise extract_error_details against several error message shapes."""

    def test_extract_403_error(self) -> None:
        """A complete 403 message yields every field and is flagged as an auth error."""
        parsed = extract_error_details(
            Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
        )

        assert parsed["status_code"] == 403
        assert parsed["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking"
        assert parsed["body"] == "Forbidden"
        assert parsed["error_type"] == "http_403"
        # 403 is expected to be classified as auth-related, not model-related.
        assert parsed["is_auth_error"] is True
        assert parsed["is_model_error"] is False

    def test_extract_422_error(self) -> None:
        """A complete 422 message yields every field and is flagged as a model error."""
        parsed = extract_error_details(
            Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
        )

        assert parsed["status_code"] == 422
        assert parsed["model_name"] == "meta-llama/Llama-3.1-70B-Instruct"
        assert parsed["body"] == "Unprocessable Entity"
        assert parsed["error_type"] == "http_422"
        # 422 is expected to be classified as model-related, not auth-related.
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is True

    def test_extract_partial_error(self) -> None:
        """A message carrying only a status code leaves the other fields as None."""
        parsed = extract_error_details(Exception("status_code: 500"))

        assert parsed["status_code"] == 500
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        assert parsed["error_type"] == "http_500"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is False

    def test_extract_generic_error(self) -> None:
        """A message without any recognised field is reported as an unknown error."""
        parsed = extract_error_details(Exception("Something went wrong"))

        assert parsed["status_code"] is None
        assert parsed["model_name"] is None
        assert parsed["body"] is None
        assert parsed["error_type"] == "unknown"
        assert parsed["is_auth_error"] is False
        assert parsed["is_model_error"] is False
class TestGetUserFriendlyErrorMessage:
    """Verify the text produced by get_user_friendly_error_message."""

    def test_403_error_message(self) -> None:
        """The 403 message names the auth problem, the model and the response body."""
        rendered = get_user_friendly_error_message(
            Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
        )

        # Checked in the same order as a reader would scan the message.
        for fragment in (
            "Authentication Error",
            "inference-api",
            "Qwen/Qwen3-Next-80B-A3B-Thinking",
            "Forbidden",
        ):
            assert fragment in rendered

    def test_422_error_message(self) -> None:
        """The 422 message names the compatibility problem, the model and the body."""
        rendered = get_user_friendly_error_message(
            Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
        )

        for fragment in (
            "Model Compatibility Error",
            "meta-llama/Llama-3.1-70B-Instruct",
            "Unprocessable Entity",
        ):
            assert fragment in rendered

    def test_generic_error_message(self) -> None:
        """An unrecognised error falls back to a generic message that echoes the text."""
        rendered = get_user_friendly_error_message(Exception("Something went wrong"))

        for fragment in ("API Error", "Something went wrong"):
            assert fragment in rendered

    def test_error_message_with_model_name_param(self) -> None:
        """The model_name argument shows up when the error text carries no model."""
        rendered = get_user_friendly_error_message(
            Exception("status_code: 403, body: Forbidden"),
            model_name="test-model",
        )

        assert "test-model" in rendered
class TestValidateHfToken:
    """Tests for validate_hf_token function.

    validate_hf_token returns an ``(is_valid, error_msg)`` pair. The valid-token
    tests below show that ``error_msg`` can be ``None``, so every rejection test
    first asserts that a message is present before searching inside it. Without
    that guard a regression would surface as a ``TypeError`` from the ``in``
    operator instead of a readable assertion failure.
    """

    def test_valid_token(self) -> None:
        """Should validate a valid token."""
        token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
        is_valid, error_msg = validate_hf_token(token)
        assert is_valid is True
        assert error_msg is None

    def test_none_token(self) -> None:
        """Should reject None token."""
        is_valid, error_msg = validate_hf_token(None)
        assert is_valid is False
        assert error_msg is not None
        assert "None or empty" in error_msg

    def test_empty_token(self) -> None:
        """Should reject empty token."""
        is_valid, error_msg = validate_hf_token("")
        assert is_valid is False
        assert error_msg is not None
        assert "None or empty" in error_msg

    def test_non_string_token(self) -> None:
        """Should reject non-string token."""
        # Deliberately passes the wrong type; the ignore silences the type checker.
        is_valid, error_msg = validate_hf_token(123)  # type: ignore[arg-type]
        assert is_valid is False
        assert error_msg is not None
        assert "not a string" in error_msg

    def test_short_token(self) -> None:
        """Should reject token that's too short."""
        is_valid, error_msg = validate_hf_token("hf_123")
        assert is_valid is False
        assert error_msg is not None
        assert "too short" in error_msg

    def test_oauth_token_format(self) -> None:
        """Should accept OAuth tokens (may not start with hf_)."""
        # OAuth tokens may have different formats
        token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
        is_valid, error_msg = validate_hf_token(token)
        assert is_valid is True
        assert error_msg is None
class TestLogTokenInfo:
    """Verify what log_token_info passes to the module logger's debug call."""

    @patch("src.utils.hf_error_handler.logger")
    def test_log_valid_token(self, mock_logger) -> None:
        """A well-formed token is logged once with its validity, length and prefix."""
        sample_token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

        log_token_info(sample_token, context="test")

        mock_logger.debug.assert_called_once()
        # call_args unpacks into (positional args, keyword args) of the last call.
        positional, keyword = mock_logger.debug.call_args
        assert positional[0] == "Token validation"
        assert keyword["context"] == "test"
        assert keyword["has_token"] is True
        assert keyword["is_valid"] is True
        assert keyword["token_length"] == len(sample_token)
        assert "token_prefix" in keyword

    @patch("src.utils.hf_error_handler.logger")
    def test_log_none_token(self, mock_logger) -> None:
        """A missing token is still logged once, with has_token set to False."""
        log_token_info(None, context="test")

        mock_logger.debug.assert_called_once()
        # call_args unpacks into (positional args, keyword args) of the last call.
        positional, keyword = mock_logger.debug.call_args
        assert positional[0] == "Token validation"
        assert keyword["context"] == "test"
        assert keyword["has_token"] is False
class TestShouldRetryWithFallback:
    """Verify which errors should_retry_with_fallback treats as retryable."""

    def test_403_error_should_retry(self) -> None:
        """A 403 error naming a model is retryable."""
        failure = Exception("status_code: 403, model_name: test-model, body: Forbidden")
        decision = should_retry_with_fallback(failure)
        assert decision is True

    def test_422_error_should_retry(self) -> None:
        """A 422 error naming a model is retryable."""
        failure = Exception("status_code: 422, model_name: test-model, body: Unprocessable Entity")
        decision = should_retry_with_fallback(failure)
        assert decision is True

    def test_model_specific_error_should_retry(self) -> None:
        """A 500 error that still names a model is retryable."""
        failure = Exception("status_code: 500, model_name: test-model, body: Error")
        decision = should_retry_with_fallback(failure)
        assert decision is True

    def test_generic_error_should_not_retry(self) -> None:
        """An error with neither status code nor model information is not retryable."""
        failure = Exception("Something went wrong")
        decision = should_retry_with_fallback(failure)
        assert decision is False
class TestGetFallbackModels:
    """Verify the candidate list returned by get_fallback_models."""

    def test_get_fallback_models_default(self) -> None:
        """With no arguments, a non-empty list containing the Llama 8B model is returned."""
        candidates = get_fallback_models()

        assert len(candidates) > 0
        assert "meta-llama/Llama-3.1-8B-Instruct" in candidates
        assert isinstance(candidates, list)

    def test_get_fallback_models_excludes_original(self) -> None:
        """The model passed as original_model is left out of the returned candidates."""
        failed_model = "meta-llama/Llama-3.1-8B-Instruct"

        candidates = get_fallback_models(original_model=failed_model)

        assert failed_model not in candidates
        assert len(candidates) > 0

    def test_get_fallback_models_with_unknown_original(self) -> None:
        """An original_model that is not a known fallback does not shrink the result."""
        candidates = get_fallback_models(original_model="unknown/model")

        # Nothing should be filtered out here, so at least 3 fallbacks are expected.
        assert len(candidates) >= 3