Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import os | |
| import sys | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from fastapi import HTTPException, Request, status | |
| from prisma import errors as prisma_errors | |
| from prisma.errors import ( | |
| ClientNotConnectedError, | |
| DataError, | |
| ForeignKeyViolationError, | |
| HTTPClientClosedError, | |
| MissingRequiredValueError, | |
| PrismaError, | |
| RawQueryError, | |
| RecordNotFoundError, | |
| TableNotFoundError, | |
| UniqueViolationError, | |
| ) | |
| sys.path.insert( | |
| 0, os.path.abspath("../../..") | |
| ) # Adds the parent directory to the system path | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.proxy._types import ProxyErrorTypes, ProxyException | |
| from litellm.proxy.auth.auth_exception_handler import UserAPIKeyAuthExceptionHandler | |
@pytest.mark.asyncio
async def test_handle_authentication_error_db_unavailable(prisma_error):
    """Check the handler's result when requests are allowed despite a DB error.

    ``litellm.proxy.proxy_server.general_settings`` is patched so that
    ``allow_requests_on_db_unavailable`` is ``True``. Under that setting the
    call to ``_handle_authentication_error`` must return (not raise) an object
    whose ``key_name`` and ``token`` are both ``"failed-to-connect-to-db"``.

    Args:
        prisma_error: Exception instance forwarded to the handler.
            NOTE(review): no parametrization or fixture providing this
            argument is visible in this part of the file -- confirm it is
            supplied (e.g. ``pytest.mark.parametrize`` or a conftest fixture).
    """
    handler = UserAPIKeyAuthExceptionHandler()

    # Stand-ins for the positional arguments the handler receives.
    mock_request = MagicMock()
    mock_request_data = {}
    mock_route = "/test"
    mock_span = None
    mock_api_key = "test-key"

    settings_override = {"allow_requests_on_db_unavailable": True}
    with patch("litellm.proxy.proxy_server.general_settings", settings_override):
        result = await handler._handle_authentication_error(
            prisma_error,
            mock_request,
            mock_request_data,
            mock_route,
            mock_span,
            mock_api_key,
        )

        # Both fields carry the same sentinel value.
        assert result.key_name == "failed-to-connect-to-db"
        assert result.token == "failed-to-connect-to-db"
@pytest.mark.asyncio
async def test_handle_authentication_error_budget_exceeded():
    """Check that a ``BudgetExceededError`` surfaces as a budget ``ProxyException``.

    The handler is given a ``BudgetExceededError`` and is expected to raise a
    ``ProxyException`` whose ``type`` equals
    ``ProxyErrorTypes.budget_exceeded``.
    """
    # Imported locally, as in the original test.
    from litellm.exceptions import BudgetExceededError

    handler = UserAPIKeyAuthExceptionHandler()

    # Stand-ins for the positional arguments the handler receives.
    mock_request = MagicMock()
    mock_request_data = {}
    mock_route = "/test"
    mock_span = None
    mock_api_key = "test-key"

    # Build the error outside ``pytest.raises`` so that only the handler call
    # itself can satisfy the expectation.
    budget_error = BudgetExceededError(
        message="Budget exceeded", current_cost=100, max_budget=100
    )

    with pytest.raises(ProxyException) as exc_info:
        await handler._handle_authentication_error(
            budget_error,
            mock_request,
            mock_request_data,
            mock_route,
            mock_span,
            mock_api_key,
        )

    assert exc_info.value.type == ProxyErrorTypes.budget_exceeded
@pytest.mark.asyncio
async def test_route_passed_to_post_call_failure_hook():
    """Check that the request route reaches ``post_call_failure_hook``.

    This route is used by proxy track_cost_callback's
    async_post_call_failure_hook to check if the route is an LLM route.

    ``proxy_logging_obj.post_call_failure_hook`` is replaced with an
    ``AsyncMock`` and ``allow_requests_on_db_unavailable`` is patched to
    ``False``. After the handler has processed a ``PrismaError``, the hook must
    have been called exactly once with a ``user_api_key_dict`` keyword argument
    whose ``request_route`` equals the route passed to the handler.
    """
    handler = UserAPIKeyAuthExceptionHandler()

    # Stand-ins for the positional arguments the handler receives.
    mock_request = MagicMock()
    mock_request_data = {}
    test_route = "/custom/route"
    mock_span = None
    mock_api_key = "test-key"

    with patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
        new_callable=AsyncMock,
    ) as mock_post_call_failure_hook:
        with patch(
            "litellm.proxy.proxy_server.general_settings",
            {"allow_requests_on_db_unavailable": False},
        ):
            try:
                await handler._handle_authentication_error(
                    PrismaError(),
                    mock_request,
                    mock_request_data,
                    test_route,
                    mock_span,
                    mock_api_key,
                )
            except Exception:
                # Deliberately ignored: this test only inspects the hook call,
                # not whether or how the handler raises.
                pass

            # The sleep must be awaited; a bare ``asyncio.sleep(1)`` only
            # creates a coroutine object and never pauses.
            # NOTE(review): presumably this gives a hook scheduled in the
            # background time to run -- confirm against the handler.
            await asyncio.sleep(1)

        # Verify post_call_failure_hook was called with the correct route.
        mock_post_call_failure_hook.assert_called_once()
        hook_kwargs = mock_post_call_failure_hook.call_args[1]
        assert hook_kwargs["user_api_key_dict"].request_route == test_route