import copy
import uuid
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import Request, status
from fastapi.responses import StreamingResponse

import litellm
from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
    ProxyBaseLLMRequestProcessing,
    ProxyConfig,
    _parse_event_data_for_error,
    create_streaming_response,
)
from litellm.proxy.utils import ProxyLogging
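
# The tests below are all coroutines. Assuming the suite runs with
# pytest-asyncio (an assumption about the project's test setup), a
# module-level marker collects them as async tests without repeating a
# decorator on every method.
pytestmark = pytest.mark.asyncio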


class TestProxyBaseLLMRequestProcessing:
    async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id(
        self, monkeypatch
    ):
        processing_obj = ProxyBaseLLMRequestProcessing(data={})
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}

        async def mock_add_litellm_data_to_request(*args, **kwargs):
            return {}

        async def mock_common_processing_pre_call_logic(
            user_api_key_dict, data, call_type
        ):
            data_copy = copy.deepcopy(data)
            return data_copy

        mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
        mock_proxy_logging_obj.pre_call_hook = AsyncMock(
            side_effect=mock_common_processing_pre_call_logic
        )
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            mock_add_litellm_data_to_request,
        )

        mock_general_settings = {}
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_proxy_config = MagicMock(spec=ProxyConfig)
        route_type = "acompletion"

        # Call the actual method.
        (
            returned_data,
            logging_obj,
        ) = await processing_obj.common_processing_pre_call_logic(
            request=mock_request,
            general_settings=mock_general_settings,
            user_api_key_dict=mock_user_api_key_dict,
            proxy_logging_obj=mock_proxy_logging_obj,
            proxy_config=mock_proxy_config,
            route_type=route_type,
        )

        mock_proxy_logging_obj.pre_call_hook.assert_called_once()
        _, call_kwargs = mock_proxy_logging_obj.pre_call_hook.call_args
        data_passed = call_kwargs.get("data", {})

        assert "litellm_call_id" in data_passed
        try:
            uuid.UUID(data_passed["litellm_call_id"])
        except ValueError:
            pytest.fail("litellm_call_id is not a valid UUID")
        assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"]


class TestCommonRequestProcessingHelpers:
    async def consume_stream(self, streaming_response: StreamingResponse) -> list:
        content = []
        async for chunk_bytes in streaming_response.body_iterator:
            content.append(chunk_bytes)
        return content

    # NOTE: the parametrize cases below are illustrative assumptions; the
    # original cases were not preserved in this copy of the file. They mirror
    # the behaviour exercised by the streaming tests further down: an SSE
    # "data:" line carrying an error object yields its code, anything else
    # yields None.
    @pytest.mark.parametrize(
        "event_line, expected_code",
        [
            ('data: {"error": {"code": 403, "message": "forbidden"}}\n\n', 403),
            ('data: {"content": "hello"}\n\n', None),
            ("data: [DONE]\n\n", None),
        ],
    )
    async def test_parse_event_data_for_error(self, event_line, expected_code):
        assert await _parse_event_data_for_error(event_line) == expected_code
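
    # Context for the cases above: the assumed contract of
    # _parse_event_data_for_error (a sketch, not the actual implementation)
    # is to strip the leading "data: " prefix, JSON-parse the payload, and
    # return payload["error"]["code"] when an error object is present,
    # otherwise None (covering [DONE] sentinels and non-error chunks).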

    async def test_create_streaming_response_first_chunk_is_error(self):
        async def mock_generator():
            yield 'data: {"error": {"code": 403, "message": "forbidden"}}\n\n'
            yield 'data: {"content": "more data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_streaming_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_403_FORBIDDEN

        content = await self.consume_stream(response)
        assert content == [
            'data: {"error": {"code": 403, "message": "forbidden"}}\n\n',
            'data: {"content": "more data"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_first_chunk_not_error(self):
        async def mock_generator():
            yield 'data: {"content": "first part"}\n\n'
            yield 'data: {"content": "second part"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_streaming_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK

        content = await self.consume_stream(response)
        assert content == [
            'data: {"content": "first part"}\n\n',
            'data: {"content": "second part"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_empty_generator(self):
        async def mock_generator():
            if False:  # Never yields
                yield
            # Implicitly raises StopAsyncIteration

        response = await create_streaming_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK

        content = await self.consume_stream(response)
        assert content == []

    async def test_create_streaming_response_generator_raises_stop_async_iteration_immediately(
        self,
    ):
        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = StopAsyncIteration

        response = await create_streaming_response(mock_gen, "text/event-stream", {})
        assert response.status_code == status.HTTP_200_OK

        content = await self.consume_stream(response)
        assert content == []

    async def test_create_streaming_response_generator_raises_unexpected_exception(
        self,
    ):
        mock_gen = AsyncMock()
        mock_gen.__anext__.side_effect = ValueError("Test error from generator")

        response = await create_streaming_response(mock_gen, "text/event-stream", {})
        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

        content = await self.consume_stream(response)
        expected_error_data = {
            "error": {
                "message": "Error processing stream start",
                "code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            }
        }
        assert len(content) == 2

        # Use json.dumps to match the formatting in create_streaming_response's exception handler
        import json

        assert content[0] == f"data: {json.dumps(expected_error_data)}\n\n"
        assert content[1] == "data: [DONE]\n\n"

    async def test_create_streaming_response_first_chunk_error_string_code(self):
        async def mock_generator():
            yield 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_streaming_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS

        content = await self.consume_stream(response)
        assert content == [
            'data: {"error": {"code": "429", "message": "too many requests"}}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_custom_headers(self):
        async def mock_generator():
            yield 'data: {"content": "data"}\n\n'
            yield "data: [DONE]\n\n"

        custom_headers = {"X-Custom-Header": "TestValue"}
        response = await create_streaming_response(
            mock_generator(), "text/event-stream", custom_headers
        )
        assert response.headers["x-custom-header"] == "TestValue"

    async def test_create_streaming_response_non_default_status_code(self):
        async def mock_generator():
            yield 'data: {"content": "data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_streaming_response(
            mock_generator(),
            "text/event-stream",
            {},
            default_status_code=status.HTTP_201_CREATED,
        )
        assert response.status_code == status.HTTP_201_CREATED

        content = await self.consume_stream(response)
        assert content == [
            'data: {"content": "data"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_first_chunk_is_done(self):
        async def mock_generator():
            yield "data: [DONE]\n\n"

        response = await create_streaming_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK  # Default status

        content = await self.consume_stream(response)
        assert content == ["data: [DONE]\n\n"]

    async def test_create_streaming_response_first_chunk_is_empty_data(self):
        async def mock_generator():
            yield "data: \n\n"
            yield 'data: {"content": "actual data"}\n\n'
            yield "data: [DONE]\n\n"

        response = await create_streaming_response(
            mock_generator(), "text/event-stream", {}
        )
        assert response.status_code == status.HTTP_200_OK  # Default status

        content = await self.consume_stream(response)
        assert content == [
            "data: \n\n",
            'data: {"content": "actual data"}\n\n',
            "data: [DONE]\n\n",
        ]

    async def test_create_streaming_response_all_chunks_have_dd_trace(self):
        """Test that all stream chunks are wrapped with dd trace at the streaming generator level"""
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with multiple chunks
        async def mock_generator():
            yield 'data: {"content": "chunk 1"}\n\n'
            yield 'data: {"content": "chunk 2"}\n\n'
            yield 'data: {"content": "chunk 3"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer in the common_request_processing module
        with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
            response = await create_streaming_response(
                mock_generator(), "text/event-stream", {}
            )
            assert response.status_code == 200

            # Consume the stream to trigger the tracer calls
            content = await self.consume_stream(response)

            # Verify all chunks are present
            assert len(content) == 4
            assert content[0] == 'data: {"content": "chunk 1"}\n\n'
            assert content[1] == 'data: {"content": "chunk 2"}\n\n'
            assert content[2] == 'data: {"content": "chunk 3"}\n\n'
            assert content[3] == "data: [DONE]\n\n"

            # Verify that tracer.trace was called for each chunk (4 chunks total)
            assert mock_tracer.trace.call_count == 4

            # Verify that each call was made with the correct operation name
            actual_calls = mock_tracer.trace.call_args_list
            assert len(actual_calls) == 4
            for i, call in enumerate(actual_calls):
                args, kwargs = call
                assert (
                    args[0] == "streaming.chunk.yield"
                ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"

    async def test_create_streaming_response_dd_trace_with_error_chunk(self):
        """Test that dd trace is applied even when the first chunk contains an error"""
        from unittest.mock import patch

        # Create a mock tracer
        mock_tracer = MagicMock()
        mock_span = MagicMock()
        mock_tracer.trace.return_value.__enter__.return_value = mock_span
        mock_tracer.trace.return_value.__exit__.return_value = None

        # Mock generator with error in first chunk
        async def mock_generator():
            yield 'data: {"error": {"code": 400, "message": "bad request"}}\n\n'
            yield 'data: {"content": "chunk after error"}\n\n'
            yield "data: [DONE]\n\n"

        # Patch the tracer in the common_request_processing module
        with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
            response = await create_streaming_response(
                mock_generator(), "text/event-stream", {}
            )
            # Even with an error, the status code reflects the error code but tracing should still work
            assert response.status_code == 400

            # Consume the stream to trigger the tracer calls
            content = await self.consume_stream(response)

            # Verify all chunks are present
            assert len(content) == 3

            # Verify that tracer.trace was called for each chunk
            assert mock_tracer.trace.call_count == 3

            # Verify that each call was made with the correct operation name
            actual_calls = mock_tracer.trace.call_args_list
            assert len(actual_calls) == 3
            for i, call in enumerate(actual_calls):
                args, kwargs = call
                assert (
                    args[0] == "streaming.chunk.yield"
                ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"