Spaces:
Paused
Paused
| import argparse | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from stream.main import builtin, main, parse_args | |
| from stream.proxy_server import ProxyServer | |
@pytest.fixture
def proxy_server():
    """Build the ProxyServer instance shared by the tests in this module.

    Registered as a pytest fixture so that any test declaring a
    ``proxy_server`` parameter receives a freshly constructed server.
    The fixture only constructs the object; it does not start it.

    Returns:
        A ``ProxyServer`` created with host ``127.0.0.1``, port ``8080`` and
        two interception patterns: the wildcard ``*.example.com`` and the
        exact name ``api.test.com``.
    """
    return ProxyServer(
        host="127.0.0.1",
        port=8080,
        intercept_domains=["*.example.com", "api.test.com"],
    )
def test_should_intercept(proxy_server):
    """Check should_intercept() against exact, wildcard and unrelated hosts."""
    expected_hits = (
        "api.test.com",  # listed verbatim in intercept_domains
        "sub.example.com",  # one label below the wildcard
        "deep.sub.example.com",  # several labels below the wildcard
    )
    expected_misses = (
        "other.com",
        "example.org",
        # Edge case: "*.example.com" is expected to behave like the suffix
        # ".example.com", which the bare "example.com" does not end with.
        "example.com",
    )

    for host in expected_hits:
        assert proxy_server.should_intercept(host) is True

    for host in expected_misses:
        assert proxy_server.should_intercept(host) is False
async def test_handle_client_connect(proxy_server):
    """A CONNECT request line is handed to _handle_connect with its target.

    NOTE(review): no asyncio marker is visible on the coroutine tests in
    this file; presumably the suite runs with an asyncio pytest plugin in
    auto mode -- confirm against the project's pytest configuration.
    """
    client_reader, client_writer = AsyncMock(), AsyncMock()
    client_reader.readline.return_value = b"CONNECT api.test.com:443 HTTP/1.1\r\n"

    connect_stub = AsyncMock()
    with patch.object(proxy_server, "_handle_connect", connect_stub):
        await proxy_server.handle_client(client_reader, client_writer)

    connect_stub.assert_called_once()
    # The target string is expected as the third positional argument.
    positional = connect_stub.call_args.args
    assert positional[2] == "api.test.com:443"
    client_writer.close.assert_called_once()
async def test_handle_client_empty(proxy_server):
    """An empty first read (``b""``) still ends with the writer being closed."""
    client_writer = AsyncMock()
    client_reader = AsyncMock()
    client_reader.readline.return_value = b""

    await proxy_server.handle_client(client_reader, client_writer)

    client_writer.close.assert_called()
async def test_handle_client_exception(proxy_server):
    """A failing readline() must not escape handle_client().

    The test passes only if awaiting handle_client() does not raise and the
    writer is closed afterwards. Per the original note, the error is
    expected to be logged and CancelledError is the exception to this rule.
    """
    client_writer = AsyncMock()
    client_reader = AsyncMock()
    client_reader.readline.side_effect = Exception("Read error")

    await proxy_server.handle_client(client_reader, client_writer)

    client_writer.close.assert_called()
async def test_handle_connect_intercept(proxy_server):
    """CONNECT to an intercepted domain goes through the TLS interception path.

    Every collaborator is stubbed; the test then checks that a certificate
    was requested for the host, loaded into the SSL context, something was
    written back to the client, an upstream connection was opened and the
    intercepting forwarder was invoked.
    """
    client_reader, client_writer = AsyncMock(), AsyncMock()
    target = "api.test.com:443"

    # Certificate manager hands back fixed cert/key paths.
    cert_lookup = MagicMock(return_value=("cert_path", "key_path"))
    proxy_server.cert_manager.get_domain_cert = cert_lookup

    # Event loop whose start_tls() resolves to a dummy transport.
    fake_loop = MagicMock()
    fake_loop.start_tls = AsyncMock(return_value=MagicMock())

    # SSL context returned by ssl.create_default_context().
    fake_ssl_ctx = MagicMock()

    # Upstream connection resolves to a (reader, writer) pair of mocks.
    upstream_stub = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    forward_stub = AsyncMock()

    with (
        patch("asyncio.get_running_loop", return_value=fake_loop),
        patch("ssl.create_default_context", return_value=fake_ssl_ctx),
        patch.object(proxy_server.proxy_connector, "create_connection", upstream_stub),
        patch.object(proxy_server, "_forward_data_with_interception", forward_stub),
        patch("asyncio.StreamWriter", MagicMock()),
    ):
        await proxy_server._handle_connect(client_reader, client_writer, target)

    # Mocks keep their call records after the patches are undone.
    cert_lookup.assert_called_with("api.test.com")
    fake_ssl_ctx.load_cert_chain.assert_called()
    client_writer.write.assert_called()
    upstream_stub.assert_called()
    forward_stub.assert_called()
async def test_handle_connect_no_intercept(proxy_server):
    """CONNECT to a non-intercepted domain is tunnelled without TLS.

    Expects an upstream connection to ("other.com", 443) with ``ssl=None``
    and exactly one call to ``_forward_data``.
    """
    client_reader, client_writer = AsyncMock(), AsyncMock()
    target = "other.com:443"

    upstream_stub = AsyncMock(return_value=(AsyncMock(), AsyncMock()))
    forward_stub = AsyncMock()

    with (
        patch.object(proxy_server.proxy_connector, "create_connection", upstream_stub),
        patch.object(proxy_server, "_forward_data", forward_stub),
        patch("asyncio.gather", AsyncMock()),
    ):
        await proxy_server._handle_connect(client_reader, client_writer, target)

    # The upstream connection must be opened via the proxy connector.
    upstream_stub.assert_called_with("other.com", 443, ssl=None)
    # Only the number of _forward_data calls is checked, not its arguments.
    assert forward_stub.call_count == 1
def test_parse_args():
    """Test that parse_args() returns what ArgumentParser.parse_args yields."""
    canned = argparse.Namespace(
        host="1.2.3.4",
        port=9999,
        domains=["*.test.com"],
        proxy=None,
    )

    # Replace the real parser so no actual command line is read.
    with patch("argparse.ArgumentParser.parse_args", return_value=canned):
        parsed = parse_args()

    assert parsed.host == "1.2.3.4"
    assert parsed.port == 9999
async def test_main():
    """main() builds exactly one ProxyServer and calls start() on it once.

    Argument parsing, the ProxyServer class and ``Path.mkdir`` are all
    patched, so the real CLI is never read and no directory is created.
    """
    cli_args = argparse.Namespace(
        host="127.0.0.1",
        port=3120,
        domains=["*.google.com"],
        proxy=None,
    )
    server_stub = AsyncMock()

    with (
        patch("stream.main.parse_args", return_value=cli_args),
        patch("stream.main.ProxyServer", return_value=server_stub) as server_cls,
        patch("pathlib.Path.mkdir", MagicMock()),
    ):
        await main()

    server_cls.assert_called_once()
    server_stub.start.assert_called_once()
async def test_builtin():
    """builtin(port=...) passes the port to ProxyServer and starts it once.

    The ProxyServer class and ``Path.mkdir`` are patched, so no real server
    is built and no directory is created.
    """
    server_stub = AsyncMock()

    with (
        patch("stream.main.ProxyServer", return_value=server_stub) as server_cls,
        patch("pathlib.Path.mkdir", MagicMock()),
    ):
        await builtin(port=1234)

    server_cls.assert_called_once()
    # The port must reach the constructor as a keyword argument.
    assert server_cls.call_args.kwargs["port"] == 1234
    server_stub.start.assert_called_once()