from unittest.mock import AsyncMock, MagicMock, patch import pytest from stream.proxy_server import ProxyServer @pytest.fixture def mock_deps(): with ( patch("stream.proxy_server.CertificateManager") as MockCertMgr, patch("stream.proxy_server.ProxyConnector") as MockConnector, patch("stream.proxy_server.HttpInterceptor") as MockInterceptor, patch("pathlib.Path.mkdir"), ): mock_cert_mgr = MockCertMgr.return_value mock_connector = MockConnector.return_value mock_interceptor = MockInterceptor.return_value yield { "cert_mgr": mock_cert_mgr, "connector": mock_connector, "interceptor": mock_interceptor, } @pytest.fixture def proxy_server(mock_deps): with patch("logging.getLogger"): server = ProxyServer( host="127.0.0.1", port=8080, intercept_domains=["example.com"] ) return server @pytest.mark.asyncio async def test_handle_client_empty_request(proxy_server): """Test handling client with empty request line.""" mock_reader = AsyncMock() mock_writer = MagicMock() mock_writer.wait_closed = AsyncMock() mock_reader.readline.return_value = b"" await proxy_server.handle_client(mock_reader, mock_writer) mock_writer.close.assert_called() @pytest.mark.asyncio async def test_handle_client_exception(proxy_server): """Test handling client with exception during read.""" mock_reader = AsyncMock() mock_writer = MagicMock() mock_writer.wait_closed = AsyncMock() mock_reader.readline.side_effect = Exception("Read error") await proxy_server.handle_client(mock_reader, mock_writer) # Logger should have logged the error proxy_server.logger.error.assert_called() mock_writer.close.assert_called() @pytest.mark.asyncio async def test_handle_connect_no_transport(proxy_server, mock_deps): """Test CONNECT when transport is None.""" mock_reader = AsyncMock() mock_writer = MagicMock() mock_writer.drain = AsyncMock() # Fix: drain must be awaitable # Setup for interception proxy_server.should_intercept = MagicMock(return_value=True) # Mock transport as None mock_writer.transport = None await proxy_server._handle_connect(mock_reader, mock_writer, "example.com:443") # Should warn and return proxy_server.logger.warning.assert_called_with( "Client writer transport is None for example.com:443 before TLS upgrade. Closing." ) @pytest.mark.asyncio async def test_handle_connect_start_tls_fail(proxy_server, mock_deps): """Test CONNECT when start_tls returns None.""" mock_reader = AsyncMock() mock_writer = MagicMock() mock_writer.drain = AsyncMock() # Fix: drain must be awaitable mock_transport = MagicMock() mock_writer.transport = mock_transport proxy_server.should_intercept = MagicMock(return_value=True) mock_loop = MagicMock() mock_loop.start_tls = AsyncMock(return_value=None) with ( patch("asyncio.get_running_loop", return_value=mock_loop), patch("ssl.create_default_context"), ): await proxy_server._handle_connect(mock_reader, mock_writer, "example.com:443") proxy_server.logger.error.assert_called_with( "loop.start_tls returned None for example.com:443, which is unexpected. Closing connection.", exc_info=True, ) mock_writer.close.assert_called() @pytest.mark.asyncio async def test_forward_data_with_interception_invalid_http(proxy_server, mock_deps): """Test forwarding with interception when request line is invalid.""" client_reader = AsyncMock() client_writer = MagicMock() server_reader = AsyncMock() server_writer = MagicMock() server_writer.drain = AsyncMock() # Capture written data because client_buffer is cleared written_data = [] def capture_write(data): written_data.append(bytes(data)) # Convert bytearray to bytes copy server_writer.write.side_effect = capture_write # Invalid HTTP request (no spaces) invalid_request = b"INVALID_REQUEST\r\nHeader: val\r\n\r\n" client_reader.read.side_effect = [invalid_request, b""] server_reader.read.return_value = b"" # Server closes immediately await proxy_server._forward_data_with_interception( client_reader, client_writer, server_reader, server_writer, "example.com" ) # Should have forwarded raw buffer assert invalid_request in written_data @pytest.mark.asyncio async def test_forward_data_with_interception_partial_data(proxy_server, mock_deps): """Test forwarding with interception when data arrives in chunks.""" client_reader = AsyncMock() MagicMock() server_reader = AsyncMock() server_writer = MagicMock() server_writer.drain = AsyncMock() # Split request into chunks chunk1 = b"POST /path " chunk2 = b"HTTP/1.1\r\n" chunk3 = b"Host: example.com\r\n\r\nBody" client_reader.read.side_effect = [chunk1, chunk2, chunk3, b""] server_reader.read.return_value = b"" # Server closes immediately # Setup interceptor to avoid errors mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed") # See explanation in bug reproduction test pass @pytest.mark.asyncio async def test_forward_data_with_interception_split_headers_bug_reproduction( proxy_server, mock_deps ): """ Test that split headers cause interception to be skipped (or fail). This test documents current behavior which might be buggy. """ client_reader = AsyncMock() client_writer = MagicMock() server_reader = AsyncMock() server_writer = MagicMock() server_writer.drain = AsyncMock() chunk1 = b"POST /GenerateContent HTTP/1.1\r\nHost: e" chunk2 = b"xample.com\r\n\r\nBody" client_reader.read.side_effect = [chunk1, chunk2, b""] server_reader.read.return_value = b"" # Server closes immediately await proxy_server._forward_data_with_interception( client_reader, client_writer, server_reader, server_writer, "example.com" ) # Because of the potential bug, it will forward chunk1 immediately. server_writer.write.assert_any_call(chunk1) # Verify process_request was NOT called assert not mock_deps["interceptor"].process_request.called @pytest.mark.asyncio async def test_forward_data_with_interception_http_error_response( proxy_server, mock_deps ): """Test interception with HTTP error response (4xx/5xx). When the upstream returns an error status, we should: 1. Log the error 2. Send an error payload to the queue (fail-fast) """ import json client_reader = AsyncMock() client_writer = MagicMock() client_writer.write = MagicMock() server_reader = AsyncMock() server_writer = MagicMock() server_writer.drain = AsyncMock() # Setup a queue to capture error payload mock_queue = MagicMock() proxy_server.queue = mock_queue # Client sends a GenerateContent request client_data = b"POST /GenerateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody" client_reader.read.side_effect = [client_data, b""] # Server returns 429 Too Many Requests server_response = b"HTTP/1.1 429 Too Many Requests\r\nContent-Type: text/plain\r\n\r\nRate limited" server_reader.read.side_effect = [server_response, b""] # Setup interceptor mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed") await proxy_server._forward_data_with_interception( client_reader, client_writer, server_reader, server_writer, "example.com" ) # Verify error payload was sent to queue mock_queue.put.assert_called() call_args = mock_queue.put.call_args[0][0] parsed = json.loads(call_args) assert parsed["error"] is True assert parsed["status"] == 429 assert "429" in parsed["message"] assert parsed["done"] is True @pytest.mark.asyncio async def test_handle_connect_connection_error_with_interception( proxy_server, mock_deps ): """Test CONNECT handling when connection to server fails (with interception). When create_connection fails after TLS upgrade, we should: 1. Log the error 2. Close the client writer properly """ mock_reader = AsyncMock() mock_writer = MagicMock() mock_writer.drain = AsyncMock() mock_writer.close = MagicMock() mock_writer.wait_closed = AsyncMock() # Setup for interception proxy_server.should_intercept = MagicMock(return_value=True) mock_transport = MagicMock() mock_writer.transport = mock_transport mock_loop = MagicMock() new_transport = MagicMock() mock_loop.start_tls = AsyncMock(return_value=new_transport) # Mock connector to raise connection error mock_deps["connector"].create_connection = AsyncMock( side_effect=ConnectionRefusedError("Connection refused") ) with ( patch("asyncio.get_running_loop", return_value=mock_loop), patch("ssl.create_default_context"), patch("asyncio.StreamWriter"), # Mock StreamWriter constructor ): await proxy_server._handle_connect(mock_reader, mock_writer, "example.com:443") # Should log the error proxy_server.logger.error.assert_called() error_call = str(proxy_server.logger.error.call_args) assert "example.com" in error_call @pytest.mark.asyncio async def test_handle_connect_connection_error_no_interception(proxy_server, mock_deps): """Test CONNECT handling when connection to server fails (no interception). When create_connection fails without interception, we should: 1. Log the error 2. Close the original writer properly """ mock_reader = AsyncMock() mock_writer = MagicMock() mock_writer.drain = AsyncMock() mock_writer.close = MagicMock() mock_writer.wait_closed = AsyncMock() # Setup for no interception proxy_server.should_intercept = MagicMock(return_value=False) # Mock connector to raise connection error mock_deps["connector"].create_connection = AsyncMock( side_effect=TimeoutError("Connection timed out") ) await proxy_server._handle_connect(mock_reader, mock_writer, "other.com:443") # Should log the error proxy_server.logger.error.assert_called() error_call = str(proxy_server.logger.error.call_args) assert "other.com" in error_call # Should close the writer mock_writer.close.assert_called() @pytest.mark.asyncio async def test_forward_data_with_interception_response_parsing_error( proxy_server, mock_deps ): """Test interception when response parsing fails. If the interceptor throws during response processing, we should: 1. Log the error 2. Continue forwarding data """ client_reader = AsyncMock() client_writer = MagicMock() client_writer.write = MagicMock() server_reader = AsyncMock() server_writer = MagicMock() server_writer.drain = AsyncMock() # Client sends a GenerateContent request client_data = b"POST /GenerateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody" client_reader.read.side_effect = [client_data, b""] # Server returns a valid 200 response server_response = ( b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nResponse body" ) server_reader.read.side_effect = [server_response, b""] # Setup interceptor to fail during response processing mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed") mock_deps["interceptor"].process_response = AsyncMock( side_effect=ValueError("Failed to parse response") ) await proxy_server._forward_data_with_interception( client_reader, client_writer, server_reader, server_writer, "example.com" ) # Should log the error proxy_server.logger.error.assert_called() error_call = str(proxy_server.logger.error.call_args) assert "interception" in error_call.lower() or "response" in error_call.lower()