File size: 14,404 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
"""
Tests for ProxyServer lifecycle and connection handling.

Covers:
- handle_client() - main client connection handler
- _handle_connect() - CONNECT method and SSL setup
- start() - server startup and READY signaling

These tests focus on control-flow and error-handling paths that were previously untested.
"""

import asyncio
import multiprocessing
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from stream.proxy_server import ProxyServer

# ==================== FIXTURES ====================


@pytest.fixture
def mock_deps():
    """Patch ProxyServer's collaborators and expose the resulting mocks.

    While the fixture is active, ``CertificateManager``, ``ProxyConnector`` and
    ``HttpInterceptor`` are replaced inside ``stream.proxy_server``, and
    ``logging.getLogger`` is replaced globally with a mock.

    Yields:
        dict: the mocked instances under the keys ``"cert"``, ``"connector"``
        and ``"interceptor"``.
    """
    with (
        patch("stream.proxy_server.CertificateManager") as cert_cls,
        patch("stream.proxy_server.ProxyConnector") as connector_cls,
        patch("stream.proxy_server.HttpInterceptor") as interceptor_cls,
        patch("logging.getLogger"),
    ):
        # Certificate manager: `cert_dir / <anything>` evaluates to a fixed
        # fake path, and certificate generation is a plain recording mock.
        cert_manager = cert_cls.return_value
        cert_manager.cert_dir = MagicMock()
        cert_manager.cert_dir.__truediv__ = MagicMock(return_value="fake/path")
        cert_manager.get_domain_cert = MagicMock()

        # Connector: create_connection must be awaitable.
        connector = connector_cls.return_value
        connector.create_connection = AsyncMock()

        deps = {
            "cert": cert_manager,
            "connector": connector,
            "interceptor": interceptor_cls.return_value,
        }
        yield deps


@pytest.fixture
def proxy_server(mock_deps):
    """Yield a ProxyServer built on mocked dependencies and a real queue.

    Depends on ``mock_deps`` so that its patches are active while the server
    is constructed and for the whole duration of the test.

    The queue is closed in a ``finally`` clause so it is released even when
    ``ProxyServer(...)`` itself raises before the fixture can yield.
    """
    queue = multiprocessing.Queue()
    # Immediately call cancel_join_thread to prevent the feeder thread from
    # hanging the process at exit.
    queue.cancel_join_thread()
    try:
        yield ProxyServer(
            host="127.0.0.1",
            port=3120,
            intercept_domains=["*.google.com"],
            queue=queue,
        )
    finally:
        # Explicitly close the queue to release resources.
        queue.close()


# ==================== TESTS: handle_client ====================


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_client_with_connect_method(proxy_server):
    """A CONNECT request line is handed to _handle_connect with its target."""
    client_writer = MagicMock(close=MagicMock(), wait_closed=AsyncMock())
    client_reader = AsyncMock()
    # First (and only) line read from the client is a CONNECT request.
    client_reader.readline.return_value = b"CONNECT example.com:443 HTTP/1.1\r\n"

    # Replace _handle_connect so only the dispatch in handle_client is exercised.
    connect_patch = patch.object(
        proxy_server, "_handle_connect", new_callable=AsyncMock
    )
    with connect_patch as connect_handler:
        await proxy_server.handle_client(client_reader, client_writer)

    connect_handler.assert_called_once_with(
        client_reader, client_writer, "example.com:443"
    )


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_client_empty_request_line(proxy_server):
    """An empty request line (client disconnected) leads to the writer being closed."""
    client_reader = AsyncMock()
    # readline() returning b"" simulates a connection that was closed
    # before any request was sent.
    client_reader.readline = AsyncMock(return_value=b"")
    client_writer = MagicMock(close=MagicMock(), wait_closed=AsyncMock())

    await proxy_server.handle_client(client_reader, client_writer)

    # The handler is expected to close the client connection.
    client_writer.close.assert_called()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_client_exception_handling(proxy_server):
    """A generic exception while reading is logged instead of propagating."""
    client_reader = AsyncMock()
    # Reading the request line blows up with a plain Exception.
    client_reader.readline = AsyncMock(side_effect=Exception("Read error"))
    client_writer = MagicMock(close=MagicMock(), wait_closed=AsyncMock())

    # Must return normally; an escaping exception would fail the test here.
    await proxy_server.handle_client(client_reader, client_writer)

    error_log = proxy_server.logger.error
    error_log.assert_called()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_client_cancelled_error_propagates(proxy_server):
    """asyncio.CancelledError raised while reading must reach the caller."""
    client_reader = AsyncMock()
    # Simulate task cancellation happening during the first read.
    client_reader.readline = AsyncMock(side_effect=asyncio.CancelledError())
    client_writer = MagicMock(close=MagicMock(), wait_closed=AsyncMock())

    # Cancellation is not to be swallowed by the handler's error handling.
    with pytest.raises(asyncio.CancelledError):
        await proxy_server.handle_client(client_reader, client_writer)


# ==================== TESTS: _handle_connect ====================


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_connect_non_intercepted_domain(proxy_server, mock_deps):
    """CONNECT to a domain outside the intercept list is tunnelled as-is.

    Checks that the "200 Connection Established" reply is written to the
    client and that plain forwarding (``_forward_data``) is invoked once.
    """
    # Client side of the tunnel; read() returns b"" for the header-drop step.
    client_reader = AsyncMock()
    client_reader.read = AsyncMock(return_value=b"")
    client_writer = MagicMock(
        write=MagicMock(),
        drain=AsyncMock(),
        close=MagicMock(),
        wait_closed=AsyncMock(),
    )

    # Upstream side returned by the (mocked) connector.
    upstream_reader = AsyncMock()
    upstream_writer = MagicMock(close=MagicMock(), wait_closed=AsyncMock())
    mock_deps["connector"].create_connection.return_value = (
        upstream_reader,
        upstream_writer,
    )

    # "example.com" does not match the fixture's "*.google.com" pattern.
    forward_patch = patch.object(
        proxy_server, "_forward_data", new_callable=AsyncMock
    )
    with forward_patch as forward:
        await proxy_server._handle_connect(
            client_reader, client_writer, "example.com:443"
        )

    # The most recent write to the client is the tunnel-established reply.
    client_writer.write.assert_called_with(
        b"HTTP/1.1 200 Connection Established\r\n\r\n"
    )
    # Plain forwarding was started exactly once (no interception).
    forward.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_connect_intercepted_domain_ssl_setup(proxy_server, mock_deps):
    """CONNECT to an intercepted domain goes through the TLS-upgrade path.

    Verifies, in order of the assertions below: a certificate is requested for
    the host, the SSL context loads a cert chain, ``loop.start_tls`` is called,
    and the intercepting forwarder is started.
    """
    # Client reader: read() returns b"" for the header-drop step.
    client_reader = AsyncMock()
    client_reader.read = AsyncMock(return_value=b"")

    # Client writer with a transport whose protocol can be looked up.
    client_writer = MagicMock(
        write=MagicMock(),
        drain=AsyncMock(),
        close=MagicMock(),
        wait_closed=AsyncMock(),
    )
    plain_transport = MagicMock()
    plain_transport.get_protocol.return_value = MagicMock()
    client_writer.transport = plain_transport

    # Objects handed back by the patched asyncio / ssl entry points.
    tls_transport = MagicMock()
    ssl_context = MagicMock()
    ssl_context.load_cert_chain = MagicMock()  # avoids loading real files

    # Upstream connection returned by the mocked connector.
    mock_deps["connector"].create_connection.return_value = (
        AsyncMock(),
        MagicMock(),
    )

    with (
        patch("asyncio.get_running_loop") as get_loop,
        patch("ssl.create_default_context") as create_ctx,
        # StreamWriter is mocked to avoid asyncio's internal assertions.
        patch("asyncio.StreamWriter") as stream_writer_cls,
        patch.object(
            proxy_server, "_forward_data_with_interception", new_callable=AsyncMock
        ) as intercept_forward,
    ):
        event_loop = get_loop.return_value
        event_loop.start_tls = AsyncMock(return_value=tls_transport)
        create_ctx.return_value = ssl_context
        stream_writer_cls.return_value = MagicMock()

        # "aistudio.google.com" matches the fixture's "*.google.com" pattern.
        await proxy_server._handle_connect(
            client_reader, client_writer, "aistudio.google.com:443"
        )

    # A certificate was requested for the bare host name.
    mock_deps["cert"].get_domain_cert.assert_called_with("aistudio.google.com")
    # The SSL context was given a certificate chain.
    ssl_context.load_cert_chain.assert_called_once()
    # The connection was upgraded to TLS.
    event_loop.start_tls.assert_called_once()
    # Intercepting forwarder was started.
    intercept_forward.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_connect_transport_none_before_tls(proxy_server):
    """An intercepted CONNECT with ``writer.transport is None`` logs a warning."""
    client_reader = AsyncMock()
    client_reader.read = AsyncMock(return_value=b"")
    client_writer = MagicMock(
        write=MagicMock(),
        drain=AsyncMock(),
        close=MagicMock(),
        wait_closed=AsyncMock(),
        transport=None,  # the condition under test
    )

    # Intercepted domain, so the TLS-upgrade path is the one taken.
    await proxy_server._handle_connect(
        client_reader, client_writer, "aistudio.google.com:443"
    )

    warning_log = proxy_server.logger.warning
    warning_log.assert_called()
    assert "transport is None" in str(warning_log.call_args)


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_connect_start_tls_returns_none(proxy_server, mock_deps):
    """If ``loop.start_tls`` yields None, an error is logged and the client closed."""
    client_reader = AsyncMock()
    client_reader.read = AsyncMock(return_value=b"")

    client_writer = MagicMock(
        write=MagicMock(),
        drain=AsyncMock(),
        close=MagicMock(),
        wait_closed=AsyncMock(),
    )
    plain_transport = MagicMock()
    plain_transport.get_protocol.return_value = MagicMock()
    client_writer.transport = plain_transport

    # SSL context stand-in so no certificate files are loaded.
    ssl_context = MagicMock()
    ssl_context.load_cert_chain = MagicMock()

    with (
        patch("asyncio.get_running_loop") as get_loop,
        patch("ssl.create_default_context") as create_ctx,
    ):
        # The failure being simulated: the TLS upgrade produces no transport.
        get_loop.return_value.start_tls = AsyncMock(return_value=None)
        create_ctx.return_value = ssl_context

        await proxy_server._handle_connect(
            client_reader, client_writer, "aistudio.google.com:443"
        )

        error_log = proxy_server.logger.error
        error_log.assert_called()
        assert "start_tls returned None" in str(error_log.call_args)

        # The client connection must have been closed.
        client_writer.close.assert_called()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_handle_connect_server_connection_fails(proxy_server, mock_deps):
    """A failing upstream connection is logged and the client writer is closed."""
    client_reader = AsyncMock()
    client_reader.read = AsyncMock(return_value=b"")
    client_writer = MagicMock(
        write=MagicMock(),
        drain=AsyncMock(),
        close=MagicMock(),
        wait_closed=AsyncMock(),
    )

    # Every attempt to open the upstream connection raises.
    connection_error = Exception("Connection refused")
    mock_deps["connector"].create_connection.side_effect = connection_error

    # Non-intercepted domain, so the passthrough path hits the connector.
    await proxy_server._handle_connect(
        client_reader, client_writer, "example.com:443"
    )

    proxy_server.logger.error.assert_called()
    client_writer.close.assert_called()


# ==================== TESTS: start() ====================


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_start_creates_server_and_signals_ready(proxy_server):
    """start() creates the listening server via asyncio.start_server.

    NOTE(review): despite the name, the READY signal placed on the queue is
    not asserted here; only the call to ``asyncio.start_server`` is checked.
    """
    # Fake asyncio server: one bound socket, awaitable serve_forever, and
    # async-context-manager support.
    listening_socket = MagicMock()
    listening_socket.getsockname.return_value = ("127.0.0.1", 3120)
    fake_server = MagicMock(sockets=[listening_socket], serve_forever=AsyncMock())
    fake_server.__aenter__ = AsyncMock(return_value=fake_server)
    fake_server.__aexit__ = AsyncMock(return_value=None)

    start_server_patch = patch(
        "asyncio.start_server", new_callable=AsyncMock, return_value=fake_server
    )
    with start_server_patch as start_server:
        # Run start() in the background in case it never returns on its own.
        background = asyncio.create_task(proxy_server.start())

        # Let the task run far enough to create the server.
        await asyncio.sleep(0.2)

        # Stop it and absorb the resulting cancellation, if any.
        background.cancel()
        try:
            await background
        except asyncio.CancelledError:
            pass

        start_server.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_start_handles_queue_none(mock_deps):
    """start() still creates the server when ProxyServer was built with queue=None."""
    # Built directly (not via the proxy_server fixture) to pass queue=None.
    queueless_server = ProxyServer(
        host="127.0.0.1",
        port=3120,
        intercept_domains=["*.google.com"],
        queue=None,
    )

    # Fake asyncio server: one bound socket, awaitable serve_forever, and
    # async-context-manager support.
    listening_socket = MagicMock()
    listening_socket.getsockname.return_value = ("127.0.0.1", 3120)
    fake_server = MagicMock(sockets=[listening_socket], serve_forever=AsyncMock())
    fake_server.__aenter__ = AsyncMock(return_value=fake_server)
    fake_server.__aexit__ = AsyncMock(return_value=None)

    start_server_patch = patch(
        "asyncio.start_server", new_callable=AsyncMock, return_value=fake_server
    )
    with start_server_patch as start_server:
        background = asyncio.create_task(queueless_server.start())

        # Let the task run far enough to create the server.
        await asyncio.sleep(0.2)

        background.cancel()
        try:
            await background
        except asyncio.CancelledError:
            pass

        # Awaiting the task above did not raise, and the server was created.
        start_server.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_start_logs_server_address(proxy_server):
    """start() emits a debug-level log entry while starting up.

    Only the fact that ``logger.debug`` was called is asserted; the message
    content (expected to carry the address) is not inspected.

    The background task is cancelled in a ``finally`` clause so that a failing
    assertion cannot leave ``start()`` running after the test ends.
    """
    mock_server = MagicMock()
    mock_socket = MagicMock()
    mock_socket.getsockname.return_value = ("127.0.0.1", 3120)
    mock_server.sockets = [mock_socket]
    mock_server.serve_forever = AsyncMock()
    # start() may use the server as an async context manager.
    mock_server.__aenter__ = AsyncMock(return_value=mock_server)
    mock_server.__aexit__ = AsyncMock(return_value=None)

    with patch("asyncio.start_server", new_callable=AsyncMock) as mock_start_server:
        mock_start_server.return_value = mock_server

        task = asyncio.create_task(proxy_server.start())
        try:
            # Give the task time to reach its logging call.
            await asyncio.sleep(0.2)

            # Implementation uses debug, not info.
            proxy_server.logger.debug.assert_called()
        finally:
            # Always stop the background task, even if the assertion failed.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass