File size: 16,391 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from stream.proxy_server import ProxyServer


class TestProxyServer:
    @pytest.fixture
    def mock_cert_manager(self):
        """Patch ``stream.proxy_server.CertificateManager``; yield its instance.

        The yielded mock's ``cert_dir`` is a ``MagicMock`` whose ``/`` operator
        returns the placeholder string ``"path/to/cert"``.
        """
        with patch("stream.proxy_server.CertificateManager") as manager_cls:
            fake_cert_dir = MagicMock()
            fake_cert_dir.__truediv__.return_value = "path/to/cert"

            manager = manager_cls.return_value
            manager.cert_dir = fake_cert_dir
            yield manager

    @pytest.fixture
    def mock_connector(self):
        """Patch ``stream.proxy_server.ProxyConnector``; yield its instance mock.

        ``create_connection`` is replaced by an ``AsyncMock`` that resolves to an
        ``(AsyncMock, MagicMock)`` pair standing in for ``(reader, writer)``.
        """
        with patch("stream.proxy_server.ProxyConnector") as connector_cls:
            upstream_streams = (AsyncMock(), MagicMock())
            connector = connector_cls.return_value
            connector.create_connection = AsyncMock(return_value=upstream_streams)
            yield connector

    @pytest.fixture
    def mock_interceptor(self):
        """Patch ``stream.proxy_server.HttpInterceptor``; yield its instance mock."""
        with patch("stream.proxy_server.HttpInterceptor") as interceptor_cls:
            interceptor = interceptor_cls.return_value
            yield interceptor

    @pytest.fixture
    def mock_path(self):
        """Patch ``Path`` as seen by ``stream.proxy_server``; yield the class mock."""
        with patch("stream.proxy_server.Path") as path_cls:
            yield path_cls

    @pytest.fixture
    def server(self, mock_cert_manager, mock_connector, mock_interceptor, mock_path):
        """Create the ``ProxyServer`` under test.

        The mock fixtures are requested so that their patches are active while
        the server is constructed; none of them is used directly here. The
        intercept list holds one exact host and one wildcard entry.
        """
        intercept_list = ["example.com", "*.google.com"]
        return ProxyServer(intercept_domains=intercept_list)

    def test_should_intercept(self, server):
        """Check interception decisions for exact, wildcard and unlisted hosts."""
        # The last case pins that "*.google.com" selects subdomains only: the bare
        # "google.com" must not be intercepted.
        cases = (
            ("example.com", True),  # exact entry
            ("other.com", False),  # not listed at all
            ("mail.google.com", True),  # subdomain of the wildcard entry
            ("google.com", False),  # bare domain behind the wildcard
        )
        for host, should_match in cases:
            assert bool(server.should_intercept(host)) is should_match, host

    @pytest.fixture
    def mock_writer(self):
        """Return a stream-writer stand-in.

        ``drain`` and ``wait_closed`` are awaitable (``AsyncMock``); ``close`` is
        a plain ``MagicMock``. Every other attribute is an auto-created mock.
        """
        return MagicMock(
            drain=AsyncMock(),
            wait_closed=AsyncMock(),
            close=MagicMock(),
        )

    @pytest.mark.asyncio
    async def test_handle_client_connect_intercept(self, server, mock_writer):
        """``handle_client`` hands a CONNECT target to ``_handle_connect``."""
        client_reader = AsyncMock(
            **{"readline.return_value": b"CONNECT example.com:443 HTTP/1.1"}
        )

        # Stub out _handle_connect so that only the dispatch performed by
        # handle_client is observed.
        with patch.object(
            server, "_handle_connect", new_callable=AsyncMock
        ) as connect_stub:
            await server.handle_client(client_reader, mock_writer)

        connect_stub.assert_called_once_with(
            client_reader, mock_writer, "example.com:443"
        )

    @pytest.mark.asyncio
    async def test_handle_client_not_connect(self, server, mock_writer):
        """A request whose method is not CONNECT ends with the writer closed."""
        client_reader = AsyncMock(
            **{"readline.return_value": b"GET http://example.com/ HTTP/1.1"}
        )

        await server.handle_client(client_reader, mock_writer)

        mock_writer.close.assert_called()

    @pytest.mark.asyncio
    async def test_handle_client_empty_request(self, server, mock_writer):
        """An empty request line ends with the client writer closed."""
        client_reader = AsyncMock()
        client_reader.readline = AsyncMock(return_value=b"")

        await server.handle_client(client_reader, mock_writer)

        mock_writer.close.assert_called()

    @pytest.mark.asyncio
    async def test_handle_connect_no_intercept(
        self, server, mock_connector, mock_writer
    ):
        """CONNECT to a host outside the intercept list is tunnelled without TLS.

        ``example.org`` is not in the fixture's intercept list. The test expects
        the "200 Connection Established" reply as the last write to the client,
        an upstream connection requested with ``ssl=None``, and exactly one call
        to ``_forward_data``.
        """
        client_reader = AsyncMock()

        # _forward_data is stubbed so no real relaying takes place.
        with patch.object(
            server, "_forward_data", new_callable=AsyncMock
        ) as forward_stub:
            await server._handle_connect(client_reader, mock_writer, "example.org:443")

        established_reply = b"HTTP/1.1 200 Connection Established\r\n\r\n"
        mock_writer.write.assert_called_with(established_reply)
        mock_connector.create_connection.assert_called_with(
            "example.org", 443, ssl=None
        )
        forward_stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_connect_intercept_flow(
        self, server, mock_cert_manager, mock_connector, mock_writer
    ):
        """CONNECT to an intercepted host runs the TLS interception path.

        ``example.com`` is in the fixture's intercept list. With the running
        loop, ``ssl.create_default_context`` and ``asyncio.StreamWriter``
        patched, the test expects a certificate request for the host, a
        ``start_tls`` call, an upstream connection made with a non-``None``
        ``ssl`` keyword, and exactly one call to
        ``_forward_data_with_interception``.
        """
        client_reader = AsyncMock()

        # The client writer exposes a transport (with a protocol) for the TLS
        # upgrade.
        client_transport = MagicMock()
        client_transport.get_protocol.return_value = MagicMock()
        mock_writer.transport = client_transport

        # Stand-in event loop: only start_tls has to be awaitable.
        fake_loop = MagicMock()
        fake_loop.start_tls = AsyncMock(return_value="new_transport")

        # Object handed back by the patched asyncio.StreamWriter constructor.
        tls_client_writer = MagicMock(wait_closed=AsyncMock(), close=MagicMock())

        with (
            patch("asyncio.get_running_loop", return_value=fake_loop),
            patch.object(
                server, "_forward_data_with_interception", new_callable=AsyncMock
            ) as intercept_forwarder,
            patch("ssl.create_default_context"),
            patch("asyncio.StreamWriter", return_value=tls_client_writer),
        ):
            await server._handle_connect(client_reader, mock_writer, "example.com:443")

            # A certificate must have been requested for the target host.
            mock_cert_manager.get_domain_cert.assert_called_with("example.com")

            # The client connection must have been upgraded to TLS.
            fake_loop.start_tls.assert_called()

            # The upstream connection must carry an SSL context. Only the keyword
            # arguments are inspected; .get() makes a missing "ssl" keyword fail
            # the assertion instead of raising KeyError.
            mock_connector.create_connection.assert_called()
            upstream_kwargs = mock_connector.create_connection.call_args.kwargs
            assert upstream_kwargs.get("ssl") is not None

            # Data forwarding must go through the intercepting forwarder.
            intercept_forwarder.assert_called_once()

    @pytest.mark.asyncio
    async def test_forward_data_basic(self, server, mock_writer):
        """``_forward_data`` relays a client chunk upstream and closes both writers."""
        upstream_writer = MagicMock(
            drain=AsyncMock(), wait_closed=AsyncMock(), close=MagicMock()
        )

        # Client side: one chunk of data, then EOF.
        client_reader = AsyncMock()
        client_reader.read = AsyncMock(side_effect=[b"data1", b""])

        async def delayed_eof(*args, **kwargs):
            # Keep the upstream side idle so the client side finishes first.
            await asyncio.sleep(0.1)
            return b""

        upstream_reader = AsyncMock()
        upstream_reader.read = AsyncMock(side_effect=delayed_eof)

        await server._forward_data(
            client_reader, mock_writer, upstream_reader, upstream_writer
        )

        # The client's chunk must have reached the upstream writer.
        upstream_writer.write.assert_called_with(b"data1")

        # Writes toward the client are deliberately not asserted: the upstream
        # reader is slow and may produce nothing before forwarding ends.

        # Both ends must be closed once forwarding is over.
        mock_writer.close.assert_called()
        upstream_writer.close.assert_called()

    def test_should_intercept_wildcard(self, server):
        """A lone wildcard entry selects subdomains only, returning real booleans."""
        server.intercept_domains = ["*.example.com"]

        expectations = (
            ("sub.example.com", True),
            ("example.com", False),
            ("other.com", False),
        )
        for host, expected in expectations:
            assert server.should_intercept(host) is expected

    @pytest.mark.asyncio
    async def test_handle_client_cancellation(self, server, mock_writer):
        """``CancelledError`` raised while reading propagates out of ``handle_client``."""
        cancelled_reader = AsyncMock()
        cancelled_reader.readline = AsyncMock(side_effect=asyncio.CancelledError())

        with pytest.raises(asyncio.CancelledError):
            await server.handle_client(cancelled_reader, mock_writer)

    @pytest.mark.asyncio
    async def test_forward_data_cancellation(self, server, mock_writer):
        """Cancelling a running ``_forward_data`` task surfaces ``CancelledError``."""
        upstream_writer = MagicMock(
            drain=AsyncMock(), wait_closed=AsyncMock(), close=MagicMock()
        )

        async def stalled_read(*args, **kwargs):
            # Far longer than the test waits, so both directions are still
            # blocked in read() when the task is cancelled.
            await asyncio.sleep(2)
            return b""

        client_reader = AsyncMock()
        client_reader.read = AsyncMock(side_effect=stalled_read)
        upstream_reader = AsyncMock()
        upstream_reader.read = AsyncMock(side_effect=stalled_read)

        forwarding = asyncio.create_task(
            server._forward_data(
                client_reader, mock_writer, upstream_reader, upstream_writer
            )
        )

        # Give the task time to start before cancelling it.
        await asyncio.sleep(0.1)
        forwarding.cancel()

        with pytest.raises(asyncio.CancelledError):
            await forwarding

    @pytest.mark.asyncio
    async def test_handle_client_close_error(self, server, mock_writer):
        """An exception from ``wait_closed`` does not escape ``handle_client``."""
        # An empty request line is used to reach the point where the writer is
        # closed (asserted below).
        empty_reader = AsyncMock(**{"readline.return_value": b""})
        mock_writer.wait_closed.side_effect = Exception("Close error")

        # Must return normally even though wait_closed raises.
        await server.handle_client(empty_reader, mock_writer)

        mock_writer.close.assert_called()

    @pytest.mark.asyncio
    async def test_server_start_queue_ready(self, server):
        """``start`` puts ``"READY"`` on the server's queue."""
        ready_queue = MagicMock()
        server.queue = ready_queue

        # Object returned by the patched asyncio.start_server: one socket bound
        # to a fixed address, and a serve_forever that raises immediately so
        # that start() does not block.
        listening_socket = MagicMock()
        listening_socket.getsockname.return_value = ("127.0.0.1", 8080)
        fake_listener = AsyncMock()
        fake_listener.sockets = [listening_socket]
        fake_listener.serve_forever.side_effect = asyncio.CancelledError()

        with patch("asyncio.start_server", return_value=fake_listener):
            try:
                await server.start()
            except asyncio.CancelledError:
                # Tolerated: the mocked serve_forever raises this to end start().
                pass

        ready_queue.put.assert_called_with("READY")

    @pytest.mark.asyncio
    async def test_forward_data_with_interception_flow(self, server, mock_writer):
        """A ``/generateContent`` exchange reaches the interceptor and the queue.

        The client sends one POST request and the upstream sends one JSON
        response; afterwards both interceptor hooks and ``queue.put`` must have
        been called.
        """
        upstream_writer = MagicMock(
            drain=AsyncMock(), wait_closed=AsyncMock(), close=MagicMock()
        )

        # Client side: a single request, then EOF.
        intercepted_request = (
            b"POST /generateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody"
        )
        client_reader = AsyncMock()
        client_reader.read = AsyncMock(side_effect=[intercepted_request, b""])

        # Upstream side: a single response, then EOF.
        upstream_response = (
            b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{}"
        )
        upstream_reader = AsyncMock()
        upstream_reader.read = AsyncMock(side_effect=[upstream_response, b""])

        # Awaitable interceptor hooks and a recording queue.
        interceptor = server.interceptor
        interceptor.process_request = AsyncMock(return_value=b"Body")
        interceptor.process_response = AsyncMock(return_value={"data": "test"})
        server.queue = MagicMock()

        await server._forward_data_with_interception(
            client_reader, mock_writer, upstream_reader, upstream_writer, "example.com"
        )

        interceptor.process_request.assert_called()
        interceptor.process_response.assert_called()
        server.queue.put.assert_called()

    @pytest.mark.asyncio
    async def test_forward_data_with_interception_flow_slow_cancellation(
        self, server, mock_writer
    ):
        """The call completes when the client finishes before a slow upstream.

        The intent is that the still-pending upstream direction is cancelled
        internally; that cancellation is not asserted here, only that the call
        returns without raising and the request hook was invoked.
        """
        upstream_writer = MagicMock(
            drain=AsyncMock(), wait_closed=AsyncMock(), close=MagicMock()
        )

        # The client delivers its request right away, then EOF.
        intercepted_request = (
            b"POST /generateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody"
        )
        client_reader = AsyncMock()
        client_reader.read = AsyncMock(side_effect=[intercepted_request, b""])

        async def delayed_eof(*args, **kwargs):
            # The upstream answers (with EOF) only after a short delay.
            await asyncio.sleep(0.1)
            return b""

        upstream_reader = AsyncMock()
        upstream_reader.read = AsyncMock(side_effect=delayed_eof)

        server.interceptor.process_request = AsyncMock(return_value=b"Body")
        server.queue = MagicMock()

        await server._forward_data_with_interception(
            client_reader, mock_writer, upstream_reader, upstream_writer, "example.com"
        )

        server.interceptor.process_request.assert_called()

    @pytest.mark.asyncio
    async def test_forward_data_with_interception_no_sniff(self, server, mock_writer):
        """A request to another path is forwarded unchanged, bypassing the hook.

        The request path is ``/other/path`` rather than ``/generateContent``, so
        ``process_request`` must not be called and the bytes written upstream
        must equal the bytes the client sent.
        """
        upstream_writer = MagicMock(
            drain=AsyncMock(), wait_closed=AsyncMock(), close=MagicMock()
        )

        # Snapshot every write as immutable bytes: the mock would otherwise keep
        # only a reference to a buffer that may be mutated afterwards.
        forwarded_chunks = []
        upstream_writer.write.side_effect = lambda data: forwarded_chunks.append(
            bytes(data)
        )

        plain_request = b"POST /other/path HTTP/1.1\r\nHost: example.com\r\n\r\nBody"
        client_reader = AsyncMock()
        client_reader.read = AsyncMock(side_effect=[plain_request, b""])

        # The upstream has nothing to say in this scenario: EOF straight away.
        upstream_reader = AsyncMock()
        upstream_reader.read = AsyncMock(side_effect=[b"", b""])

        server.interceptor.process_request = AsyncMock()

        await server._forward_data_with_interception(
            client_reader, mock_writer, upstream_reader, upstream_writer, "example.com"
        )

        server.interceptor.process_request.assert_not_called()
        assert b"".join(forwarded_chunks) == plain_request

    @pytest.mark.asyncio
    async def test_forward_data_with_interception_cancellation(
        self, server, mock_writer
    ):
        """Cancelling a running interception forwarder surfaces ``CancelledError``."""
        upstream_writer = MagicMock(
            drain=AsyncMock(), wait_closed=AsyncMock(), close=MagicMock()
        )

        async def stalled_read(*args, **kwargs):
            # Far longer than the test waits, so both directions are still
            # blocked in read() when the task is cancelled.
            await asyncio.sleep(2)
            return b""

        client_reader = AsyncMock()
        client_reader.read = AsyncMock(side_effect=stalled_read)
        upstream_reader = AsyncMock()
        upstream_reader.read = AsyncMock(side_effect=stalled_read)

        forwarding = asyncio.create_task(
            server._forward_data_with_interception(
                client_reader,
                mock_writer,
                upstream_reader,
                upstream_writer,
                "example.com",
            )
        )

        # Give the task time to start before cancelling it.
        await asyncio.sleep(0.1)
        forwarding.cancel()

        with pytest.raises(asyncio.CancelledError):
            await forwarding

    @pytest.mark.asyncio
    async def test_start_server_queue_error(self, server):
        """A failing ``queue.put`` does not make ``start`` raise an unexpected error.

        ``put`` raises ``Exception("Queue error")``. The test passes as long as
        ``start`` raises nothing other than the ``CancelledError`` injected via
        ``serve_forever`` and ``put`` was attempted with ``"READY"``.
        """
        failing_queue = MagicMock()
        failing_queue.put.side_effect = Exception("Queue error")
        server.queue = failing_queue

        # Object returned by the patched asyncio.start_server; serve_forever
        # raises immediately so that start() does not block.
        listening_socket = MagicMock()
        listening_socket.getsockname.return_value = ("127.0.0.1", 8080)
        fake_listener = AsyncMock()
        fake_listener.sockets = [listening_socket]
        fake_listener.serve_forever.side_effect = asyncio.CancelledError()

        with patch("asyncio.start_server", return_value=fake_listener):
            try:
                await server.start()
            except asyncio.CancelledError:
                # Tolerated: the mocked serve_forever raises this to end start().
                pass

        # The logger is not mocked, so error logging is not checked; only the
        # attempted queue notification is.
        failing_queue.put.assert_called_with("READY")