File size: 15,500 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
"""
Comprehensive tests for ProxyServer data forwarding logic.

Focuses on two critical paths that previously had no test coverage:
- _forward_data() - bidirectional forwarding without interception
- _forward_data_with_interception() - HTTP parsing and interception

When these tests were written, the two functions accounted for roughly 210
lines (about half) of proxy_server.py.
"""

import asyncio
import multiprocessing
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from stream.proxy_server import ProxyServer

# ==================== TEST HELPERS ====================


class AsyncStreamReader:
    """Fake ``asyncio.StreamReader`` with real async behavior, backed by a queue.

    Each :meth:`feed_data` call enqueues one chunk. :meth:`read` hands the
    queued bytes out in pieces of at most ``n`` bytes, keeping any remainder
    for the next call, as ``StreamReader.read(n)`` does. Code that reads in
    fixed-size blocks is therefore exercised over several reads when a
    payload is larger than its block size.

    Once EOF is reached, every later :meth:`read` returns ``b""``. EOF is
    reached either via :meth:`feed_eof` or because no data arrived within
    ``timeout`` seconds.
    """

    # Block size used internally when draining everything with read(-1).
    _READ_ALL_CHUNK = 65536

    def __init__(self, timeout: float = 2.0):
        """Create an empty reader.

        Args:
            timeout: Seconds a single read waits for new data before it
                reports EOF; keeps a stalled test from hanging forever.
        """
        self.queue = asyncio.Queue()
        self.closed = False
        self._eof = False
        # Bytes already taken from the queue but not yet returned by read().
        self._buffer = bytearray()
        self._timeout = timeout

    async def read(self, n: int = -1) -> bytes:
        """Read up to ``n`` bytes; with ``n < 0`` read everything until EOF.

        Returns ``b""`` when ``n == 0``, when the reader is closed, or at EOF.
        """
        if n == 0 or self.closed:
            return b""
        if n < 0:
            parts = []
            while True:
                part = await self.read(self._READ_ALL_CHUNK)
                if not part:
                    return b"".join(parts)
                parts.append(part)

        if not self._buffer:
            if self._eof:
                return b""
            try:
                data = await asyncio.wait_for(
                    self.queue.get(), timeout=self._timeout
                )
            except asyncio.TimeoutError:
                # Treat a stall as EOF so the code under test terminates.
                data = b""
            if not data:
                # EOF marker or timeout: latch it so EOF is sticky, as it is
                # for a real StreamReader.
                self._eof = True
                return b""
            self._buffer.extend(data)

        chunk = bytes(self._buffer[:n])
        del self._buffer[:n]
        return chunk

    def feed_data(self, data: bytes):
        """Queue ``data`` as if it had just arrived from the network.

        Empty data is ignored, as in ``StreamReader.feed_data``, so it can
        never be mistaken for the internal EOF marker. Data fed after the
        reader is closed is dropped.
        """
        if data and not self.closed:
            self.queue.put_nowait(data)

    def feed_eof(self):
        """Signal EOF; reads return ``b""`` once earlier data is consumed."""
        self.queue.put_nowait(b"")


class AsyncStreamWriter:
    """Fake ``asyncio.StreamWriter`` that records everything written to it.

    Writes made after :meth:`close` are silently dropped. :meth:`wait_closed`
    only returns once :meth:`close` has been called.
    """

    def __init__(self):
        self.data = bytearray()
        self.closed = False
        self.close_event = asyncio.Event()

    def write(self, data: bytes):
        """Record ``data`` (synchronous, like ``StreamWriter.write``)."""
        if not self.closed:
            self.data.extend(data)

    async def drain(self):
        """Nothing is really buffered; just yield control to the event loop."""
        await asyncio.sleep(0)

    def is_closing(self) -> bool:
        """Return True once :meth:`close` has been called.

        Mirrors ``StreamWriter.is_closing()`` so cleanup code that checks it
        before closing works against this fake.
        """
        return self.closed

    def close(self):
        """Mark the writer closed and wake any :meth:`wait_closed` waiters."""
        self.closed = True
        self.close_event.set()

    async def wait_closed(self):
        """Block until :meth:`close` has been called."""
        await self.close_event.wait()

    def get_data(self) -> bytes:
        """Return a snapshot of every byte written before the writer closed."""
        return bytes(self.data)


def create_stream_pair() -> "tuple[AsyncStreamReader, AsyncStreamWriter]":
    """Return a fresh ``(reader, writer)`` pair of fake streams for one peer.

    The two objects are independent: bytes written to the writer are NOT
    readable from the reader. Tests feed the reader by hand (``feed_data`` /
    ``feed_eof``) and inspect what the code under test wrote via
    ``writer.get_data()``.
    """
    reader = AsyncStreamReader()
    writer = AsyncStreamWriter()
    return reader, writer


# ==================== FIXTURES ====================


@pytest.fixture
def mock_cert_manager():
    """Patch ``stream.proxy_server.CertificateManager`` for one test.

    Yields the instance the patched class returns; ``cert_dir`` and
    ``get_domain_cert`` on it are plain ``MagicMock`` objects.
    """
    with patch("stream.proxy_server.CertificateManager") as cert_manager_cls:
        fake_manager = cert_manager_cls.return_value
        fake_manager.cert_dir = MagicMock()
        fake_manager.get_domain_cert = MagicMock()
        yield fake_manager


@pytest.fixture
def mock_proxy_connector():
    """Patch ``stream.proxy_server.ProxyConnector`` for one test.

    Yields the instance the patched class returns; its ``create_connection``
    is an ``AsyncMock`` so it can be awaited.
    """
    patcher = patch("stream.proxy_server.ProxyConnector")
    with patcher as connector_cls:
        fake_connector = connector_cls.return_value
        fake_connector.create_connection = AsyncMock()
        yield fake_connector


@pytest.fixture
def mock_interceptor():
    """Patch ``stream.proxy_server.HttpInterceptor`` and yield its fake instance.

    ``process_request`` echoes back the data it is given (its first argument),
    so requests are forwarded unchanged. ``process_response`` returns a fixed
    dict that individual tests may override via ``return_value``.
    """
    with patch("stream.proxy_server.HttpInterceptor") as mock:
        instance = mock.return_value
        # Accept any extra positional *or keyword* arguments so the fake does
        # not break if the proxy passes host/path/etc. by keyword.
        instance.process_request = AsyncMock(
            side_effect=lambda data, *args, **kwargs: data
        )
        instance.process_response = AsyncMock(return_value={"text": "mocked response"})
        yield instance


@pytest.fixture
def proxy_server(mock_cert_manager, mock_proxy_connector, mock_interceptor):
    """Yield a ``ProxyServer`` built while its collaborators are patched.

    The three mock fixtures are requested so that their patches are active
    when ``ProxyServer`` is constructed. ``logging.getLogger`` stays patched
    for the whole test. The multiprocessing queue handed to the server is
    always closed afterwards, even if constructing the server fails.
    """
    with patch("logging.getLogger"):
        queue = multiprocessing.Queue()
        # Let the process exit without the queue's feeder thread waiting to
        # flush data nobody will read.
        queue.cancel_join_thread()
        try:
            server = ProxyServer(
                host="127.0.0.1", port=3120, intercept_domains=["*.google.com"], queue=queue
            )
            yield server
        finally:
            # Release the queue's pipe/thread resources on every exit path.
            queue.close()


# ==================== TESTS: _forward_data (No Interception) ====================


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_forward_data_bidirectional_success(proxy_server):
    """Bytes flow unchanged client->server and server->client without interception."""
    request = b"GET / HTTP/1.1\r\n\r\n"
    response = b"HTTP/1.1 200 OK\r\n\r\nHello"

    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    client_reader.feed_data(request)
    client_reader.feed_eof()

    server_reader.feed_data(response)
    server_reader.feed_eof()

    await proxy_server._forward_data(
        client_reader, client_writer, server_reader, server_writer
    )

    # Exact comparison: a substring check would also pass if bytes were
    # duplicated or extra data was injected.
    assert server_writer.get_data() == request
    assert client_writer.get_data() == response


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_forward_data_handles_client_disconnect(proxy_server):
    """Both connections are closed after the client disconnects.

    The server stream also reaches EOF here, so this checks the shutdown path
    once both directions have finished. Data the client sent just before
    disconnecting must still reach the server.
    """
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    # Client sends data, then disconnects.
    client_reader.feed_data(b"Some data")
    client_reader.feed_eof()

    # Server sends a response and then also reaches EOF.
    server_reader.feed_data(b"Response data")
    server_reader.feed_eof()

    await proxy_server._forward_data(
        client_reader, client_writer, server_reader, server_writer
    )

    assert server_writer.get_data() == b"Some data"
    assert client_writer.closed
    assert server_writer.closed


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_forward_data_handles_task_cancellation(proxy_server):
    """Cancelling _forward_data propagates CancelledError and closes both writers."""
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    # Nothing is fed, so both directions stay waiting inside read().
    task = asyncio.create_task(
        proxy_server._forward_data(
            client_reader, client_writer, server_reader, server_writer
        )
    )

    # Give both forwarding directions time to start reading.
    await asyncio.sleep(0.1)

    task.cancel()

    with pytest.raises(asyncio.CancelledError):
        await task

    # Cleanup must still run for both directions when cancelled mid-read.
    assert client_writer.closed
    assert server_writer.closed


# ==================== TESTS: _forward_data_with_interception ====================


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_interception_detects_generate_content_path(
    proxy_server, mock_interceptor
):
    """GenerateContent requests are detected and their body is sent to the interceptor."""
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    # Content-Length is derived from the body. A hard-coded value (previously
    # 50 for this 43-byte body) would describe a request that never completes.
    body = b'{"contents":[{"parts":[{"text":"Hello"}]}]}'
    headers = (
        b"POST /v1/models/gemini-1.5-pro:generateContent HTTP/1.1\r\n"
        b"Host: generativelanguage.googleapis.com\r\n"
        b"Content-Length: %d\r\n"
        b"\r\n"
    ) % len(body)
    http_request = headers + body

    client_reader.feed_data(http_request)
    client_reader.feed_eof()

    # Server responds
    server_reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n{}")
    server_reader.feed_eof()

    await proxy_server._forward_data_with_interception(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        host="generativelanguage.googleapis.com",
    )

    # The interceptor saw the request and was handed the JSON body.
    mock_interceptor.process_request.assert_called()
    call_args = mock_interceptor.process_request.call_args[0]
    request_body = call_args[0]
    assert b'{"contents"' in request_body


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_interception_skips_non_generate_content_requests(
    proxy_server, mock_interceptor
):
    """Non-GenerateContent requests are forwarded untouched and never intercepted."""
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    # Plain GET (not GenerateContent)
    http_request = (
        b"GET /v1/models HTTP/1.1\r\nHost: generativelanguage.googleapis.com\r\n\r\n"
    )

    client_reader.feed_data(http_request)
    client_reader.feed_eof()

    server_reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n[]")
    server_reader.feed_eof()

    await proxy_server._forward_data_with_interception(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        host="generativelanguage.googleapis.com",
    )

    # The request still reaches the server.
    server_data = server_writer.get_data()
    assert b"GET /v1/models" in server_data

    # mock_interceptor is a function-scoped fixture, so no earlier request in
    # another test can have switched on response sniffing: neither hook fires.
    mock_interceptor.process_request.assert_not_called()
    mock_interceptor.process_response.assert_not_called()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_interception_handles_response_and_queues_data(
    proxy_server, mock_interceptor
):
    """A GenerateContent exchange hands the server's response to the interceptor."""
    c_reader, c_writer = create_stream_pair()
    s_reader, s_writer = create_stream_pair()

    # Canned result the interceptor "parses" out of the response.
    mock_interceptor.process_response.return_value = {"text": "intercepted response"}

    c_reader.feed_data(
        b"POST /v1/models/gemini:generateContent HTTP/1.1\r\n"
        b"Content-Length: 10\r\n"
        b"\r\n"
        b'{"test":1}'
    )
    c_reader.feed_eof()

    s_reader.feed_data(
        b"HTTP/1.1 200 OK\r\n"
        b"Content-Type: application/json\r\n"
        b"\r\n"
        b'{"candidates":[{"content":"response"}]}'
    )
    s_reader.feed_eof()

    await proxy_server._forward_data_with_interception(
        c_reader,
        c_writer,
        s_reader,
        s_writer,
        host="generativelanguage.googleapis.com",
    )

    mock_interceptor.process_response.assert_called()
    # NOTE: whether the parsed result is then put on the server's queue is not
    # checked here; that would need access to the queue the fixture created.


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_interception_handles_malformed_http_request(proxy_server):
    """An unparseable request line is passed through to the server instead of crashing."""
    cl_reader, cl_writer = create_stream_pair()
    sv_reader, sv_writer = create_stream_pair()

    feeds = (
        # Request line lacks the usual method / path / version parts.
        (cl_reader, b"INVALID REQUEST\r\n\r\n"),
        (sv_reader, b"HTTP/1.1 400 Bad Request\r\n\r\n"),
    )
    for reader, payload in feeds:
        reader.feed_data(payload)
        reader.feed_eof()

    # Must complete without raising.
    await proxy_server._forward_data_with_interception(
        cl_reader,
        cl_writer,
        sv_reader,
        sv_writer,
        host="generativelanguage.googleapis.com",
    )

    # Fallback path: the raw bytes still reach the server.
    assert b"INVALID REQUEST" in sv_writer.get_data()


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_interception_handles_chunked_transfer_encoding(proxy_server):
    """A chunked GenerateContent response reaches the client with every chunk intact."""
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    http_request = (
        b"POST /v1/models/gemini:generateContent HTTP/1.1\r\n"
        b"Content-Length: 2\r\n"
        b"\r\n"
        b"{}"
    )

    chunked_body = (
        b"5\r\nHello\r\n"
        b"6\r\n World\r\n"
        b"0\r\n\r\n"  # terminating zero-length chunk
    )
    chunked_response = (
        b"HTTP/1.1 200 OK\r\n"
        b"Transfer-Encoding: chunked\r\n"
        b"\r\n"
    ) + chunked_body

    client_reader.feed_data(http_request)
    client_reader.feed_eof()

    server_reader.feed_data(chunked_response)
    server_reader.feed_eof()

    await proxy_server._forward_data_with_interception(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        host="generativelanguage.googleapis.com",
    )

    # The whole chunk sequence (sizes, payloads and terminator) must arrive
    # contiguously and unmodified, not merely the end-of-body marker.
    client_data = client_writer.get_data()
    assert chunked_body in client_data


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_interception_cancellation_cleanup(proxy_server):
    """Cancelling interception mid-flight re-raises CancelledError and closes both writers."""
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    forwarding = asyncio.create_task(
        proxy_server._forward_data_with_interception(
            client_reader,
            client_writer,
            server_reader,
            server_writer,
            host="generativelanguage.googleapis.com",
        )
    )

    # Nothing has been fed, so the task is presumably parked waiting for data.
    await asyncio.sleep(0.1)
    forwarding.cancel()

    with pytest.raises(asyncio.CancelledError):
        await forwarding

    for writer in (client_writer, server_writer):
        assert writer.closed


# ==================== TESTS: Edge Cases ====================


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_forward_data_with_large_payload(proxy_server):
    """A payload far larger than one read buffer (>8192 bytes) is forwarded intact.

    The payload repeats with period 251. That period is prime, so it never
    lines up with power-of-two read sizes. As a result, dropped, duplicated or
    reordered chunks change the bytes, which a run of a single repeated byte
    would hide.
    """
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    large_payload = bytes(i % 251 for i in range(100_000))

    client_reader.feed_data(large_payload)
    client_reader.feed_eof()

    server_reader.feed_eof()

    await proxy_server._forward_data(
        client_reader, client_writer, server_reader, server_writer
    )

    server_data = server_writer.get_data()
    assert len(server_data) == len(large_payload)
    assert server_data == large_payload
    # The server sent nothing, so nothing should reach the client.
    assert client_writer.get_data() == b""


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_interception_with_incomplete_headers(proxy_server):
    """Headers whose terminating blank line arrives in a later read are still forwarded."""
    client_reader, client_writer = create_stream_pair()
    server_reader, server_writer = create_stream_pair()

    # First read delivers headers without the final blank line (\r\n\r\n).
    client_reader.feed_data(b"POST /test HTTP/1.1\r\nHost: example.com\r\n")

    # Start forwarding BEFORE the rest arrives, so the delay below happens
    # while the proxy is actually waiting on the partial headers.
    forwarding = asyncio.create_task(
        proxy_server._forward_data_with_interception(
            client_reader, client_writer, server_reader, server_writer, host="example.com"
        )
    )

    await asyncio.sleep(0.1)
    client_reader.feed_data(b"\r\n")
    client_reader.feed_eof()

    # The server only answers now, in case the proxy stops both directions as
    # soon as one of them finishes.
    server_reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n")
    server_reader.feed_eof()

    await forwarding

    # The request must be forwarded despite its headers arriving in pieces.
    server_data = server_writer.get_data()
    assert b"POST /test" in server_data