Spaces:
Paused
Paused
| import io | |
| import os | |
| import sys | |
| sys.path.insert(0, os.path.abspath("../..")) | |
| import asyncio | |
| import gzip | |
| import json | |
| import logging | |
| import time | |
| from unittest.mock import AsyncMock, patch | |
| import pytest | |
| import litellm | |
| from litellm import completion | |
| from litellm._logging import verbose_logger | |
| from litellm.proxy.utils import log_db_metrics, ServiceTypes | |
| from datetime import datetime | |
| import httpx | |
| from prisma.errors import ClientNotConnectedError | |
# Sample coroutine wrapped with log_db_metrics; the tests call it and inspect
# what the decorator reports to the (mocked) proxy service logger.
@log_db_metrics
async def sample_db_function(*args, **kwargs):
    """Return ``"success"``; exists only to be wrapped by ``log_db_metrics``."""
    return "success"
@log_db_metrics
async def sample_proxy_function(*args, **kwargs):
    """Second sample coroutine wrapped by ``log_db_metrics``; always returns ``"success"``."""
    return "success"
@pytest.mark.asyncio
async def test_log_db_metrics_success():
    """A successful decorated call reports exactly one DB success event with timing data."""
    # Patch the proxy-wide logging object so the success hook the decorator is
    # expected to call can be inspected.
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        success_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_success_hook = success_hook

        result = await sample_db_function(parent_otel_span="test_span")

        assert result == "success"
        success_hook.assert_called_once()
        call_args = success_hook.call_args.kwargs

        assert call_args["service"] == ServiceTypes.DB
        # call_type is expected to be the wrapped function's name.
        assert call_args["call_type"] == "sample_db_function"
        # The parent_otel_span kwarg passed to the wrapped call is forwarded.
        assert call_args["parent_otel_span"] == "test_span"
        assert isinstance(call_args["duration"], float)
        assert isinstance(call_args["start_time"], datetime)
        assert isinstance(call_args["end_time"], datetime)
        assert "function_name" in call_args["event_metadata"]
@pytest.mark.asyncio
async def test_log_db_metrics_duration():
    """The duration logged by ``log_db_metrics`` matches the measured wall-clock time."""
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        success_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_success_hook = success_hook

        # The decorator must be applied here; without it nothing is logged and
        # there is no duration to compare against.
        @log_db_metrics
        async def delayed_function(**kwargs):
            await asyncio.sleep(1)  # 1 second delay so the duration is clearly non-zero
            return "success"

        start = time.time()
        result = await delayed_function(parent_otel_span="test_span")
        actual_duration = time.time() - start

        assert result == "success"
        # Fail with a clear message instead of a TypeError on a missing call.
        success_hook.assert_called_once()
        logged_duration = success_hook.call_args.kwargs["duration"]

        # The logged duration should be within 0.1 s of the externally measured one.
        assert abs(logged_duration - actual_duration) < 0.1
@pytest.mark.asyncio
async def test_log_db_metrics_failure():
    """
    A Prisma error raised inside a decorated function is re-raised to the
    caller and reported once as a DB service failure.
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        failure_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_failure_hook = failure_hook

        @log_db_metrics
        async def failing_function(**kwargs):
            raise ClientNotConnectedError()

        # The decorator must not swallow the error.
        with pytest.raises(ClientNotConnectedError) as exc_info:
            await failing_function(parent_otel_span="test_span")

        assert "Client is not connected to the query engine" in str(exc_info.value)
        failure_hook.assert_called_once()
        call_args = failure_hook.call_args.kwargs

        assert call_args["service"] == ServiceTypes.DB
        assert call_args["call_type"] == "failing_function"
        assert call_args["parent_otel_span"] == "test_span"
        assert isinstance(call_args["duration"], float)
        assert isinstance(call_args["start_time"], datetime)
        assert isinstance(call_args["end_time"], datetime)
        assert isinstance(call_args["error"], ClientNotConnectedError)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "exception,should_log",
    [
        # Non-DB errors: must NOT be reported as DB service failures.
        (ValueError("Generic error"), False),
        (KeyError("Missing key"), False),
        (TypeError("Type error"), False),
        # DB-related errors: must be reported.
        (httpx.ConnectError("Failed to connect"), True),
        (httpx.TimeoutException("Request timed out"), True),
        (ClientNotConnectedError(), True),  # Prisma error
    ],
)
async def test_log_db_metrics_failure_error_types(exception, should_log):
    """
    Why test this?
    Users saw non-DB errors being logged as DB service failures; for example,
    a failure to read a value from the cache was reported as a DB failure.

    This parameterized test verifies that:
    - DB-related errors (Prisma, httpx) are logged as service failures
    - Non-DB errors (ValueError, KeyError, etc.) are not logged
    In both cases the original exception must still propagate to the caller.
    """
    with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging:
        failure_hook = AsyncMock()
        mock_proxy_logging.service_logging_obj.async_service_failure_hook = failure_hook

        @log_db_metrics
        async def failing_function(**kwargs):
            raise exception

        with pytest.raises(type(exception)):
            await failing_function(parent_otel_span="test_span")

        if not should_log:
            failure_hook.assert_not_called()
            return

        failure_hook.assert_called_once()
        call_args = failure_hook.call_args.kwargs
        assert call_args["service"] == ServiceTypes.DB
        assert call_args["call_type"] == "failing_function"
        assert call_args["parent_otel_span"] == "test_span"
        assert isinstance(call_args["duration"], float)
        assert isinstance(call_args["start_time"], datetime)
        assert isinstance(call_args["end_time"], datetime)
        assert isinstance(call_args["error"], type(exception))
| async def test_dd_log_db_spend_failure_metrics(): | |
| from litellm._service_logger import ServiceLogging | |
| from litellm.integrations.datadog.datadog import DataDogLogger | |
| dd_logger = DataDogLogger() | |
| with patch.object(dd_logger, "async_service_failure_hook", new_callable=AsyncMock): | |
| service_logging_obj = ServiceLogging() | |
| litellm.service_callback = [dd_logger] | |
| await service_logging_obj.async_service_failure_hook( | |
| service=ServiceTypes.DB, | |
| call_type="test_call_type", | |
| error="test_error", | |
| duration=1.0, | |
| ) | |