Spaces:
Paused
Paused
| import asyncio | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from stream.proxy_server import ProxyServer | |
| class TestProxyServer: | |
| def mock_cert_manager(self): | |
| with patch("stream.proxy_server.CertificateManager") as mock: | |
| instance = mock.return_value | |
| # Mock cert dir path | |
| instance.cert_dir = MagicMock() | |
| instance.cert_dir.__truediv__.return_value = "path/to/cert" | |
| yield instance | |
| def mock_connector(self): | |
| with patch("stream.proxy_server.ProxyConnector") as mock: | |
| instance = mock.return_value | |
| # Make create_connection async and return tuple (reader, writer) | |
| instance.create_connection = AsyncMock() | |
| instance.create_connection.return_value = (AsyncMock(), MagicMock()) | |
| yield instance | |
| def mock_interceptor(self): | |
| with patch("stream.proxy_server.HttpInterceptor") as mock: | |
| yield mock.return_value | |
| def mock_path(self): | |
| with patch("stream.proxy_server.Path") as mock: | |
| yield mock | |
| def server(self, mock_cert_manager, mock_connector, mock_interceptor, mock_path): | |
| return ProxyServer(intercept_domains=["example.com", "*.google.com"]) | |
| def test_should_intercept(self, server): | |
| """Test domain interception matching (exact, wildcard, subdomain logic).""" | |
| # Exact match | |
| assert server.should_intercept("example.com") | |
| # No match | |
| assert not server.should_intercept("other.com") | |
| # Wildcard match | |
| assert server.should_intercept("mail.google.com") | |
| # Wildcard logic: d[1:] is ".google.com". "google.com" ends with ".google.com"? | |
| # "google.com" does NOT end with ".google.com". | |
| # So it matches subdomains only. | |
| assert not server.should_intercept("google.com") | |
| def mock_writer(self): | |
| writer = MagicMock() | |
| writer.drain = AsyncMock() | |
| writer.wait_closed = AsyncMock() | |
| writer.close = MagicMock() | |
| return writer | |
| async def test_handle_client_connect_intercept(self, server, mock_writer): | |
| # Setup mocks | |
| reader = AsyncMock() | |
| writer = mock_writer | |
| # Mock request line | |
| reader.readline.return_value = b"CONNECT example.com:443 HTTP/1.1" | |
| # Mock _handle_connect to verify it's called | |
| with patch.object( | |
| server, "_handle_connect", new_callable=AsyncMock | |
| ) as mock_handle_connect: | |
| await server.handle_client(reader, writer) | |
| mock_handle_connect.assert_called_once_with( | |
| reader, writer, "example.com:443" | |
| ) | |
| async def test_handle_client_not_connect(self, server, mock_writer): | |
| reader = AsyncMock() | |
| writer = mock_writer | |
| # Non-CONNECT method | |
| reader.readline.return_value = b"GET http://example.com/ HTTP/1.1" | |
| await server.handle_client(reader, writer) | |
| # Verify writer closed | |
| writer.close.assert_called() | |
| async def test_handle_client_empty_request(self, server, mock_writer): | |
| reader = AsyncMock() | |
| writer = mock_writer | |
| # Empty request line | |
| reader.readline.return_value = b"" | |
| await server.handle_client(reader, writer) | |
| writer.close.assert_called() | |
| async def test_handle_connect_no_intercept( | |
| self, server, mock_connector, mock_writer | |
| ): | |
| # intercept_domains does not include example.org | |
| reader = AsyncMock() | |
| writer = mock_writer | |
| # Mock _forward_data | |
| with patch.object( | |
| server, "_forward_data", new_callable=AsyncMock | |
| ) as mock_forward: | |
| await server._handle_connect(reader, writer, "example.org:443") | |
| # Verify 200 OK sent | |
| writer.write.assert_called_with( | |
| b"HTTP/1.1 200 Connection Established\r\n\r\n" | |
| ) | |
| # Verify connection to upstream (no SSL) | |
| mock_connector.create_connection.assert_called_with( | |
| "example.org", 443, ssl=None | |
| ) | |
| # Verify forward called | |
| mock_forward.assert_called_once() | |
| async def test_handle_connect_intercept_flow( | |
| self, server, mock_cert_manager, mock_connector, mock_writer | |
| ): | |
| # Target example.com (in intercept list) | |
| reader = AsyncMock() | |
| writer = mock_writer | |
| # Mock transport for TLS upgrade | |
| transport = MagicMock() | |
| writer.transport = transport | |
| transport.get_protocol.return_value = MagicMock() | |
| # Mock loop.start_tls | |
| loop = MagicMock() | |
| loop.start_tls = AsyncMock(return_value="new_transport") | |
| # Mock asyncio.StreamWriter to return a mock | |
| mock_client_writer = MagicMock() | |
| mock_client_writer.wait_closed = AsyncMock() | |
| mock_client_writer.close = MagicMock() | |
| with ( | |
| patch("asyncio.get_running_loop", return_value=loop), | |
| patch.object( | |
| server, "_forward_data_with_interception", new_callable=AsyncMock | |
| ) as mock_forward_intercept, | |
| patch("ssl.create_default_context"), | |
| patch("asyncio.StreamWriter", return_value=mock_client_writer), | |
| ): | |
| await server._handle_connect(reader, writer, "example.com:443") | |
| # Verify cert generation | |
| mock_cert_manager.get_domain_cert.assert_called_with("example.com") | |
| # Verify TLS upgrade | |
| loop.start_tls.assert_called() | |
| # Verify upstream connection with SSL | |
| mock_connector.create_connection.assert_called() | |
| args, kwargs = mock_connector.create_connection.call_args | |
| assert kwargs["ssl"] is not None | |
| # Verify interception forwarder called | |
| mock_forward_intercept.assert_called_once() | |
| async def test_forward_data_basic(self, server, mock_writer): | |
| # Test _forward_data simple flow | |
| c_reader = AsyncMock() | |
| c_writer = mock_writer | |
| s_reader = AsyncMock() | |
| s_writer = MagicMock() | |
| s_writer.drain = AsyncMock() | |
| s_writer.wait_closed = AsyncMock() | |
| s_writer.close = MagicMock() | |
| # Mock read to return data then EOF | |
| c_reader.read.side_effect = [b"data1", b""] | |
| # s_reader is slow so it will be cancelled when c_reader finishes | |
| async def slow_read(*args, **kwargs): | |
| await asyncio.sleep(0.1) | |
| return b"" | |
| s_reader.read.side_effect = slow_read | |
| await server._forward_data(c_reader, c_writer, s_reader, s_writer) | |
| # Verify writes | |
| # client_to_server task reads c_reader and writes to s_writer | |
| s_writer.write.assert_called_with(b"data1") | |
| # server_to_client task reads s_reader (slow) and writes to c_writer | |
| # Since it's slow, it might not have written anything before cancellation | |
| # c_writer.write.assert_called_with(b"data2") | |
| # Verify closes | |
| c_writer.close.assert_called() | |
| s_writer.close.assert_called() | |
| def test_should_intercept_wildcard(self, server): | |
| """Test wildcard domain interception (matches subdomains only).""" | |
| server.intercept_domains = ["*.example.com"] | |
| assert server.should_intercept("sub.example.com") is True | |
| assert server.should_intercept("example.com") is False | |
| assert server.should_intercept("other.com") is False | |
| async def test_handle_client_cancellation(self, server, mock_writer): | |
| mock_reader = AsyncMock() | |
| mock_reader.readline.side_effect = asyncio.CancelledError() | |
| with pytest.raises(asyncio.CancelledError): | |
| await server.handle_client(mock_reader, mock_writer) | |
| async def test_forward_data_cancellation(self, server, mock_writer): | |
| # Test cancellation for basic forward data | |
| c_reader = AsyncMock() | |
| c_writer = mock_writer | |
| s_reader = AsyncMock() | |
| s_writer = MagicMock() | |
| s_writer.drain = AsyncMock() | |
| s_writer.wait_closed = AsyncMock() | |
| s_writer.close = MagicMock() | |
| # Define slow read | |
| async def slow_read(*args, **kwargs): | |
| await asyncio.sleep(2) | |
| return b"" | |
| c_reader.read.side_effect = slow_read | |
| s_reader.read.side_effect = slow_read | |
| task = asyncio.create_task( | |
| server._forward_data(c_reader, c_writer, s_reader, s_writer) | |
| ) | |
| await asyncio.sleep(0.1) | |
| task.cancel() | |
| with pytest.raises(asyncio.CancelledError): | |
| await task | |
| async def test_handle_client_close_error(self, server, mock_writer): | |
| # Test error during writer.wait_closed | |
| mock_reader = AsyncMock() | |
| mock_reader.readline.return_value = b"" # Empty request triggers close | |
| mock_writer.wait_closed.side_effect = Exception("Close error") | |
| await server.handle_client(mock_reader, mock_writer) | |
| # Should not raise exception | |
| mock_writer.close.assert_called() | |
| async def test_server_start_queue_ready(self, server): | |
| mock_queue = MagicMock() | |
| server.queue = mock_queue | |
| # Mock asyncio.start_server to return a mock server | |
| mock_server = AsyncMock() | |
| mock_server.sockets = [MagicMock()] | |
| mock_server.sockets[0].getsockname.return_value = ("127.0.0.1", 8080) | |
| # We need to mock serve_forever to stop immediately or throw exception to exit | |
| mock_server.serve_forever.side_effect = asyncio.CancelledError() | |
| with patch("asyncio.start_server", return_value=mock_server): | |
| try: | |
| await server.start() | |
| except asyncio.CancelledError: | |
| pass | |
| mock_queue.put.assert_called_with("READY") | |
| async def test_forward_data_with_interception_flow(self, server, mock_writer): | |
| # Setup mocks for interception flow | |
| client_reader = AsyncMock() | |
| client_writer = mock_writer | |
| server_reader = AsyncMock() | |
| server_writer = MagicMock() | |
| server_writer.drain = AsyncMock() | |
| server_writer.wait_closed = AsyncMock() | |
| server_writer.close = MagicMock() | |
| # Mock client sending a request to be intercepted | |
| request_data = ( | |
| b"POST /generateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody" | |
| ) | |
| client_reader.read.side_effect = [request_data, b""] | |
| # Mock server sending a response | |
| response_data = b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{}" | |
| server_reader.read.side_effect = [response_data, b""] | |
| # Mock interceptor | |
| server.interceptor.process_request = AsyncMock(return_value=b"Body") | |
| server.interceptor.process_response = AsyncMock(return_value={"data": "test"}) | |
| # Mock queue | |
| server.queue = MagicMock() | |
| await server._forward_data_with_interception( | |
| client_reader, client_writer, server_reader, server_writer, "example.com" | |
| ) | |
| # Verify interception occurred | |
| server.interceptor.process_request.assert_called() | |
| server.interceptor.process_response.assert_called() | |
| server.queue.put.assert_called() | |
| async def test_forward_data_with_interception_flow_slow_cancellation( | |
| self, server, mock_writer | |
| ): | |
| # Test cancellation of pending task when one side finishes first | |
| client_reader = AsyncMock() | |
| client_writer = mock_writer | |
| server_reader = AsyncMock() | |
| server_writer = MagicMock() | |
| server_writer.drain = AsyncMock() | |
| server_writer.wait_closed = AsyncMock() | |
| server_writer.close = MagicMock() | |
| # Client sends request fast | |
| request_data = ( | |
| b"POST /generateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody" | |
| ) | |
| client_reader.read.side_effect = [request_data, b""] | |
| # Server is slow | |
| async def slow_read(*args, **kwargs): | |
| await asyncio.sleep(0.1) | |
| return b"" | |
| server_reader.read.side_effect = slow_read | |
| server.interceptor.process_request = AsyncMock(return_value=b"Body") | |
| server.queue = MagicMock() | |
| await server._forward_data_with_interception( | |
| client_reader, client_writer, server_reader, server_writer, "example.com" | |
| ) | |
| # This should complete without error, and server reader task should be cancelled internally | |
| server.interceptor.process_request.assert_called() | |
| async def test_forward_data_with_interception_no_sniff(self, server, mock_writer): | |
| # Path does not contain GenerateContent | |
| client_reader = AsyncMock() | |
| client_writer = mock_writer | |
| server_reader = AsyncMock() | |
| server_writer = MagicMock() | |
| server_writer.drain = AsyncMock() | |
| server_writer.wait_closed = AsyncMock() | |
| server_writer.close = MagicMock() | |
| # Capture written data because mock stores reference to mutable bytearray | |
| written_data = [] | |
| def capture_write(data): | |
| written_data.append(bytes(data)) | |
| server_writer.write.side_effect = capture_write | |
| request_data = b"POST /other/path HTTP/1.1\r\nHost: example.com\r\n\r\nBody" | |
| client_reader.read.side_effect = [request_data, b""] | |
| server_reader.read.side_effect = [ | |
| b"", | |
| b"", | |
| ] # No response needed for this test part | |
| server.interceptor.process_request = AsyncMock() | |
| await server._forward_data_with_interception( | |
| client_reader, client_writer, server_reader, server_writer, "example.com" | |
| ) | |
| # Should not call process_request | |
| server.interceptor.process_request.assert_not_called() | |
| # Should write original buffer | |
| assert b"".join(written_data) == request_data | |
| async def test_forward_data_with_interception_cancellation( | |
| self, server, mock_writer | |
| ): | |
| client_reader = AsyncMock() | |
| client_writer = mock_writer | |
| server_reader = AsyncMock() | |
| server_writer = MagicMock() | |
| server_writer.drain = AsyncMock() | |
| server_writer.wait_closed = AsyncMock() | |
| server_writer.close = MagicMock() | |
| # Slow read to allow cancellation | |
| async def slow_read(*args, **kwargs): | |
| await asyncio.sleep(2) | |
| return b"" | |
| client_reader.read.side_effect = slow_read | |
| server_reader.read.side_effect = slow_read | |
| task = asyncio.create_task( | |
| server._forward_data_with_interception( | |
| client_reader, | |
| client_writer, | |
| server_reader, | |
| server_writer, | |
| "example.com", | |
| ) | |
| ) | |
| await asyncio.sleep(0.1) | |
| task.cancel() | |
| with pytest.raises(asyncio.CancelledError): | |
| await task | |
| async def test_start_server_queue_error(self, server): | |
| mock_queue = MagicMock() | |
| server.queue = mock_queue | |
| mock_queue.put.side_effect = Exception("Queue error") | |
| mock_server = AsyncMock() | |
| mock_server.sockets = [MagicMock()] | |
| mock_server.sockets[0].getsockname.return_value = ("127.0.0.1", 8080) | |
| mock_server.serve_forever.side_effect = asyncio.CancelledError() | |
| with patch("asyncio.start_server", return_value=mock_server): | |
| try: | |
| await server.start() | |
| except asyncio.CancelledError: | |
| pass | |
| # Should log error but not crash before serve_forever | |
| # If it crashed, serve_forever wouldn't be called (but we mocked it to raise CancelledError) | |
| # We can check if logger.error was called | |
| # But we didn't mock logger in fixture explicitly, it uses real logger or default. | |
| # Let's check if queue.put was called. | |
| mock_queue.put.assert_called_with("READY") | |