AIstudioProxyAPI / tests /stream /test_proxy_server_edge_cases.py
peijun1's picture
Deploy AI Studio Proxy API to Hugging Face Spaces
a5784e9
Raw
History Blame Contribute Delete
12.1 kB
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from stream.proxy_server import ProxyServer
@pytest.fixture
def mock_deps():
with (
patch("stream.proxy_server.CertificateManager") as MockCertMgr,
patch("stream.proxy_server.ProxyConnector") as MockConnector,
patch("stream.proxy_server.HttpInterceptor") as MockInterceptor,
patch("pathlib.Path.mkdir"),
):
mock_cert_mgr = MockCertMgr.return_value
mock_connector = MockConnector.return_value
mock_interceptor = MockInterceptor.return_value
yield {
"cert_mgr": mock_cert_mgr,
"connector": mock_connector,
"interceptor": mock_interceptor,
}
@pytest.fixture
def proxy_server(mock_deps):
with patch("logging.getLogger"):
server = ProxyServer(
host="127.0.0.1", port=8080, intercept_domains=["example.com"]
)
return server
@pytest.mark.asyncio
async def test_handle_client_empty_request(proxy_server):
"""Test handling client with empty request line."""
mock_reader = AsyncMock()
mock_writer = MagicMock()
mock_writer.wait_closed = AsyncMock()
mock_reader.readline.return_value = b""
await proxy_server.handle_client(mock_reader, mock_writer)
mock_writer.close.assert_called()
@pytest.mark.asyncio
async def test_handle_client_exception(proxy_server):
"""Test handling client with exception during read."""
mock_reader = AsyncMock()
mock_writer = MagicMock()
mock_writer.wait_closed = AsyncMock()
mock_reader.readline.side_effect = Exception("Read error")
await proxy_server.handle_client(mock_reader, mock_writer)
# Logger should have logged the error
proxy_server.logger.error.assert_called()
mock_writer.close.assert_called()
@pytest.mark.asyncio
async def test_handle_connect_no_transport(proxy_server, mock_deps):
"""Test CONNECT when transport is None."""
mock_reader = AsyncMock()
mock_writer = MagicMock()
mock_writer.drain = AsyncMock() # Fix: drain must be awaitable
# Setup for interception
proxy_server.should_intercept = MagicMock(return_value=True)
# Mock transport as None
mock_writer.transport = None
await proxy_server._handle_connect(mock_reader, mock_writer, "example.com:443")
# Should warn and return
proxy_server.logger.warning.assert_called_with(
"Client writer transport is None for example.com:443 before TLS upgrade. Closing."
)
@pytest.mark.asyncio
async def test_handle_connect_start_tls_fail(proxy_server, mock_deps):
"""Test CONNECT when start_tls returns None."""
mock_reader = AsyncMock()
mock_writer = MagicMock()
mock_writer.drain = AsyncMock() # Fix: drain must be awaitable
mock_transport = MagicMock()
mock_writer.transport = mock_transport
proxy_server.should_intercept = MagicMock(return_value=True)
mock_loop = MagicMock()
mock_loop.start_tls = AsyncMock(return_value=None)
with (
patch("asyncio.get_running_loop", return_value=mock_loop),
patch("ssl.create_default_context"),
):
await proxy_server._handle_connect(mock_reader, mock_writer, "example.com:443")
proxy_server.logger.error.assert_called_with(
"loop.start_tls returned None for example.com:443, which is unexpected. Closing connection.",
exc_info=True,
)
mock_writer.close.assert_called()
@pytest.mark.asyncio
async def test_forward_data_with_interception_invalid_http(proxy_server, mock_deps):
"""Test forwarding with interception when request line is invalid."""
client_reader = AsyncMock()
client_writer = MagicMock()
server_reader = AsyncMock()
server_writer = MagicMock()
server_writer.drain = AsyncMock()
# Capture written data because client_buffer is cleared
written_data = []
def capture_write(data):
written_data.append(bytes(data)) # Convert bytearray to bytes copy
server_writer.write.side_effect = capture_write
# Invalid HTTP request (no spaces)
invalid_request = b"INVALID_REQUEST\r\nHeader: val\r\n\r\n"
client_reader.read.side_effect = [invalid_request, b""]
server_reader.read.return_value = b"" # Server closes immediately
await proxy_server._forward_data_with_interception(
client_reader, client_writer, server_reader, server_writer, "example.com"
)
# Should have forwarded raw buffer
assert invalid_request in written_data
@pytest.mark.asyncio
async def test_forward_data_with_interception_partial_data(proxy_server, mock_deps):
"""Test forwarding with interception when data arrives in chunks."""
client_reader = AsyncMock()
MagicMock()
server_reader = AsyncMock()
server_writer = MagicMock()
server_writer.drain = AsyncMock()
# Split request into chunks
chunk1 = b"POST /path "
chunk2 = b"HTTP/1.1\r\n"
chunk3 = b"Host: example.com\r\n\r\nBody"
client_reader.read.side_effect = [chunk1, chunk2, chunk3, b""]
server_reader.read.return_value = b"" # Server closes immediately
# Setup interceptor to avoid errors
mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed")
# See explanation in bug reproduction test
pass
@pytest.mark.asyncio
async def test_forward_data_with_interception_split_headers_bug_reproduction(
proxy_server, mock_deps
):
"""
Test that split headers cause interception to be skipped (or fail).
This test documents current behavior which might be buggy.
"""
client_reader = AsyncMock()
client_writer = MagicMock()
server_reader = AsyncMock()
server_writer = MagicMock()
server_writer.drain = AsyncMock()
chunk1 = b"POST /GenerateContent HTTP/1.1\r\nHost: e"
chunk2 = b"xample.com\r\n\r\nBody"
client_reader.read.side_effect = [chunk1, chunk2, b""]
server_reader.read.return_value = b"" # Server closes immediately
await proxy_server._forward_data_with_interception(
client_reader, client_writer, server_reader, server_writer, "example.com"
)
# Because of the potential bug, it will forward chunk1 immediately.
server_writer.write.assert_any_call(chunk1)
# Verify process_request was NOT called
assert not mock_deps["interceptor"].process_request.called
@pytest.mark.asyncio
async def test_forward_data_with_interception_http_error_response(
proxy_server, mock_deps
):
"""Test interception with HTTP error response (4xx/5xx).
When the upstream returns an error status, we should:
1. Log the error
2. Send an error payload to the queue (fail-fast)
"""
import json
client_reader = AsyncMock()
client_writer = MagicMock()
client_writer.write = MagicMock()
server_reader = AsyncMock()
server_writer = MagicMock()
server_writer.drain = AsyncMock()
# Setup a queue to capture error payload
mock_queue = MagicMock()
proxy_server.queue = mock_queue
# Client sends a GenerateContent request
client_data = b"POST /GenerateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody"
client_reader.read.side_effect = [client_data, b""]
# Server returns 429 Too Many Requests
server_response = b"HTTP/1.1 429 Too Many Requests\r\nContent-Type: text/plain\r\n\r\nRate limited"
server_reader.read.side_effect = [server_response, b""]
# Setup interceptor
mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed")
await proxy_server._forward_data_with_interception(
client_reader, client_writer, server_reader, server_writer, "example.com"
)
# Verify error payload was sent to queue
mock_queue.put.assert_called()
call_args = mock_queue.put.call_args[0][0]
parsed = json.loads(call_args)
assert parsed["error"] is True
assert parsed["status"] == 429
assert "429" in parsed["message"]
assert parsed["done"] is True
@pytest.mark.asyncio
async def test_handle_connect_connection_error_with_interception(
proxy_server, mock_deps
):
"""Test CONNECT handling when connection to server fails (with interception).
When create_connection fails after TLS upgrade, we should:
1. Log the error
2. Close the client writer properly
"""
mock_reader = AsyncMock()
mock_writer = MagicMock()
mock_writer.drain = AsyncMock()
mock_writer.close = MagicMock()
mock_writer.wait_closed = AsyncMock()
# Setup for interception
proxy_server.should_intercept = MagicMock(return_value=True)
mock_transport = MagicMock()
mock_writer.transport = mock_transport
mock_loop = MagicMock()
new_transport = MagicMock()
mock_loop.start_tls = AsyncMock(return_value=new_transport)
# Mock connector to raise connection error
mock_deps["connector"].create_connection = AsyncMock(
side_effect=ConnectionRefusedError("Connection refused")
)
with (
patch("asyncio.get_running_loop", return_value=mock_loop),
patch("ssl.create_default_context"),
patch("asyncio.StreamWriter"), # Mock StreamWriter constructor
):
await proxy_server._handle_connect(mock_reader, mock_writer, "example.com:443")
# Should log the error
proxy_server.logger.error.assert_called()
error_call = str(proxy_server.logger.error.call_args)
assert "example.com" in error_call
@pytest.mark.asyncio
async def test_handle_connect_connection_error_no_interception(proxy_server, mock_deps):
"""Test CONNECT handling when connection to server fails (no interception).
When create_connection fails without interception, we should:
1. Log the error
2. Close the original writer properly
"""
mock_reader = AsyncMock()
mock_writer = MagicMock()
mock_writer.drain = AsyncMock()
mock_writer.close = MagicMock()
mock_writer.wait_closed = AsyncMock()
# Setup for no interception
proxy_server.should_intercept = MagicMock(return_value=False)
# Mock connector to raise connection error
mock_deps["connector"].create_connection = AsyncMock(
side_effect=TimeoutError("Connection timed out")
)
await proxy_server._handle_connect(mock_reader, mock_writer, "other.com:443")
# Should log the error
proxy_server.logger.error.assert_called()
error_call = str(proxy_server.logger.error.call_args)
assert "other.com" in error_call
# Should close the writer
mock_writer.close.assert_called()
@pytest.mark.asyncio
async def test_forward_data_with_interception_response_parsing_error(
proxy_server, mock_deps
):
"""Test interception when response parsing fails.
If the interceptor throws during response processing, we should:
1. Log the error
2. Continue forwarding data
"""
client_reader = AsyncMock()
client_writer = MagicMock()
client_writer.write = MagicMock()
server_reader = AsyncMock()
server_writer = MagicMock()
server_writer.drain = AsyncMock()
# Client sends a GenerateContent request
client_data = b"POST /GenerateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody"
client_reader.read.side_effect = [client_data, b""]
# Server returns a valid 200 response
server_response = (
b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nResponse body"
)
server_reader.read.side_effect = [server_response, b""]
# Setup interceptor to fail during response processing
mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed")
mock_deps["interceptor"].process_response = AsyncMock(
side_effect=ValueError("Failed to parse response")
)
await proxy_server._forward_data_with_interception(
client_reader, client_writer, server_reader, server_writer, "example.com"
)
# Should log the error
proxy_server.logger.error.assert_called()
error_call = str(proxy_server.logger.error.call_args)
assert "interception" in error_call.lower() or "response" in error_call.lower()