File size: 12,073 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from stream.proxy_server import ProxyServer


@pytest.fixture
def mock_deps():
    """Replace ProxyServer's collaborators with mocks for the test's duration.

    CertificateManager, ProxyConnector and HttpInterceptor are patched where
    ``stream.proxy_server`` looks them up, and ``Path.mkdir`` is patched to a
    no-op so no directories are created while the fixture is active.

    Yields:
        dict with keys ``"cert_mgr"``, ``"connector"`` and ``"interceptor"``,
        each mapped to the instance mock that the patched class returns.
    """
    with (
        patch("stream.proxy_server.CertificateManager") as cert_cls,
        patch("stream.proxy_server.ProxyConnector") as connector_cls,
        patch("stream.proxy_server.HttpInterceptor") as interceptor_cls,
        patch("pathlib.Path.mkdir"),
    ):
        # Tests configure behavior on the instances, not on the classes.
        yield dict(
            cert_mgr=cert_cls.return_value,
            connector=connector_cls.return_value,
            interceptor=interceptor_cls.return_value,
        )


@pytest.fixture
def proxy_server(mock_deps):
    """Build a ProxyServer on 127.0.0.1:8080 that intercepts ``example.com``.

    ``logging.getLogger`` is patched while the server is constructed. The
    tests rely on this to assert on ``proxy_server.logger``, which presumes
    ProxyServer obtains its logger via ``logging.getLogger``; confirm against
    the implementation.
    """
    intercepted = ["example.com"]
    with patch("logging.getLogger"):
        return ProxyServer(host="127.0.0.1", port=8080, intercept_domains=intercepted)


@pytest.mark.asyncio
async def test_handle_client_empty_request(proxy_server):
    """An empty request line from the client still ends with the writer closed."""
    reader = AsyncMock()
    reader.readline.return_value = b""

    writer = MagicMock()
    writer.wait_closed = AsyncMock()  # must be awaitable for the close path

    await proxy_server.handle_client(reader, writer)

    writer.close.assert_called()


@pytest.mark.asyncio
async def test_handle_client_exception(proxy_server):
    """A failure while reading the request line is logged and the writer closed."""
    writer = MagicMock()
    writer.wait_closed = AsyncMock()

    reader = AsyncMock()
    reader.readline.side_effect = Exception("Read error")

    await proxy_server.handle_client(reader, writer)

    # Both the error report and the connection teardown must happen.
    proxy_server.logger.error.assert_called()
    writer.close.assert_called()


@pytest.mark.asyncio
async def test_handle_connect_no_transport(proxy_server, mock_deps):
    """CONNECT to an intercepted host whose client writer has no transport.

    The handler is expected to emit a warning naming the target and bail out
    before attempting the TLS upgrade.
    """
    proxy_server.should_intercept = MagicMock(return_value=True)

    reader = AsyncMock()
    writer = MagicMock()
    writer.drain = AsyncMock()  # the handler awaits drain(), so it must be awaitable
    writer.transport = None

    await proxy_server._handle_connect(reader, writer, "example.com:443")

    expected_warning = (
        "Client writer transport is None for example.com:443 before TLS upgrade. Closing."
    )
    proxy_server.logger.warning.assert_called_with(expected_warning)


@pytest.mark.asyncio
async def test_handle_connect_start_tls_fail(proxy_server, mock_deps):
    """CONNECT where ``loop.start_tls`` yields ``None`` instead of a transport.

    The handler should log the unexpected result (with traceback info) and
    close the client writer.
    """
    proxy_server.should_intercept = MagicMock(return_value=True)

    reader = AsyncMock()
    writer = MagicMock()
    writer.drain = AsyncMock()
    writer.transport = MagicMock()

    fake_loop = MagicMock()
    fake_loop.start_tls = AsyncMock(return_value=None)

    with patch("asyncio.get_running_loop", return_value=fake_loop), patch(
        "ssl.create_default_context"
    ):
        await proxy_server._handle_connect(reader, writer, "example.com:443")

    proxy_server.logger.error.assert_called_with(
        "loop.start_tls returned None for example.com:443, which is unexpected. Closing connection.",
        exc_info=True,
    )
    writer.close.assert_called()


@pytest.mark.asyncio
async def test_forward_data_with_interception_invalid_http(proxy_server, mock_deps):
    """Traffic whose first line is not a valid request line is passed through raw."""
    # A first line without spaces cannot be split into method/path/version.
    raw = b"INVALID_REQUEST\r\nHeader: val\r\n\r\n"

    forwarded = []
    server_writer = MagicMock()
    server_writer.drain = AsyncMock()
    # The proxy's client buffer is cleared after each write, so snapshot every
    # write as an immutable bytes copy.
    server_writer.write.side_effect = lambda data: forwarded.append(bytes(data))

    client_reader = AsyncMock()
    client_reader.read.side_effect = [raw, b""]
    server_reader = AsyncMock()
    server_reader.read.return_value = b""  # upstream closes straight away

    await proxy_server._forward_data_with_interception(
        client_reader, MagicMock(), server_reader, server_writer, "example.com"
    )

    assert raw in forwarded


@pytest.mark.asyncio
@pytest.mark.skip(
    reason=(
        "Requests split across reads are not reassembled before interception; "
        "see test_forward_data_with_interception_split_headers_bug_reproduction."
    )
)
async def test_forward_data_with_interception_partial_data(proxy_server, mock_deps):
    """Forwarding with interception when the request arrives in several chunks.

    Skipped explicitly rather than passing vacuously. The split-headers bug
    reproduction test in this module shows that a request spanning several
    reads is currently forwarded chunk by chunk without ``process_request``
    being called, so there is no correct outcome to pin yet. The setup below
    already drives ``_forward_data_with_interception``; enabling the test
    only requires removing the skip and adding assertions.
    """
    client_reader = AsyncMock()
    client_writer = MagicMock()
    server_reader = AsyncMock()
    server_writer = MagicMock()
    server_writer.drain = AsyncMock()

    # The request line itself is split across the first two reads.
    chunk1 = b"POST /path "
    chunk2 = b"HTTP/1.1\r\n"
    chunk3 = b"Host: example.com\r\n\r\nBody"

    client_reader.read.side_effect = [chunk1, chunk2, chunk3, b""]
    server_reader.read.return_value = b""  # Server closes immediately

    # Awaitable stub so an interception path can await it without erroring.
    mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed")

    await proxy_server._forward_data_with_interception(
        client_reader, client_writer, server_reader, server_writer, "example.com"
    )

    # TODO: once chunked requests are reassembled, assert that the complete
    # request reached the server (or was handed to process_request).


@pytest.mark.asyncio
async def test_forward_data_with_interception_split_headers_bug_reproduction(
    proxy_server, mock_deps
):
    """
    Pin the current behavior for headers split across two reads.

    The first chunk (complete request line, partial headers) is written to
    the server as-is and ``process_request`` is never invoked. This is
    probably a bug; the test documents it rather than endorsing it.
    """
    head = b"POST /GenerateContent HTTP/1.1\r\nHost: e"
    tail = b"xample.com\r\n\r\nBody"

    client_reader = AsyncMock()
    client_reader.read.side_effect = [head, tail, b""]

    server_reader = AsyncMock()
    server_reader.read.return_value = b""  # upstream closes straight away
    server_writer = MagicMock()
    server_writer.drain = AsyncMock()

    await proxy_server._forward_data_with_interception(
        client_reader, MagicMock(), server_reader, server_writer, "example.com"
    )

    # The partial first chunk goes out unmodified...
    server_writer.write.assert_any_call(head)
    # ...and interception never kicks in.
    mock_deps["interceptor"].process_request.assert_not_called()


@pytest.mark.asyncio
async def test_forward_data_with_interception_http_error_response(
    proxy_server, mock_deps
):
    """Interception when upstream answers with an HTTP error status (4xx/5xx).

    The proxy is expected to put a JSON error payload on its queue so the
    consumer fails fast. The payload must carry ``error=True``, the numeric
    status, a message mentioning the status, and ``done=True``.
    """
    import json

    queue = MagicMock()
    proxy_server.queue = queue  # receives the fail-fast error payload

    request = b"POST /GenerateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody"
    rate_limited = b"HTTP/1.1 429 Too Many Requests\r\nContent-Type: text/plain\r\n\r\nRate limited"

    client_reader = AsyncMock()
    client_reader.read.side_effect = [request, b""]
    server_reader = AsyncMock()
    server_reader.read.side_effect = [rate_limited, b""]
    server_writer = MagicMock()
    server_writer.drain = AsyncMock()

    mock_deps["interceptor"].process_request = AsyncMock(return_value=b"processed")

    await proxy_server._forward_data_with_interception(
        client_reader, MagicMock(), server_reader, server_writer, "example.com"
    )

    queue.put.assert_called()
    payload = json.loads(queue.put.call_args.args[0])
    assert payload["error"] is True
    assert payload["status"] == 429
    assert "429" in payload["message"]
    assert payload["done"] is True


@pytest.mark.asyncio
async def test_handle_connect_connection_error_with_interception(
    proxy_server, mock_deps
):
    """CONNECT to an intercepted host where the upstream connection is refused.

    ``create_connection`` raises after the (mocked) TLS upgrade succeeds. Only
    the error logging is asserted: the log call must mention the target host.
    Writer closing is not checked here, since the post-upgrade writer
    presumably comes from the patched ``asyncio.StreamWriter``.
    """
    proxy_server.should_intercept = MagicMock(return_value=True)
    mock_deps["connector"].create_connection = AsyncMock(
        side_effect=ConnectionRefusedError("Connection refused")
    )

    reader = AsyncMock()
    writer = MagicMock()
    writer.drain = AsyncMock()
    writer.wait_closed = AsyncMock()
    writer.transport = MagicMock()

    fake_loop = MagicMock()
    upgraded_transport = MagicMock()
    fake_loop.start_tls = AsyncMock(return_value=upgraded_transport)

    with patch("asyncio.get_running_loop", return_value=fake_loop), patch(
        "ssl.create_default_context"
    ), patch("asyncio.StreamWriter"):
        await proxy_server._handle_connect(reader, writer, "example.com:443")

    proxy_server.logger.error.assert_called()
    assert "example.com" in str(proxy_server.logger.error.call_args)


@pytest.mark.asyncio
async def test_handle_connect_connection_error_no_interception(proxy_server, mock_deps):
    """CONNECT to a non-intercepted host where the upstream connection times out.

    The failure must be logged with the target host in the message, and the
    original client writer must be closed.
    """
    proxy_server.should_intercept = MagicMock(return_value=False)
    mock_deps["connector"].create_connection = AsyncMock(
        side_effect=TimeoutError("Connection timed out")
    )

    reader = AsyncMock()
    writer = MagicMock()
    writer.drain = AsyncMock()
    writer.wait_closed = AsyncMock()

    await proxy_server._handle_connect(reader, writer, "other.com:443")

    error_log = proxy_server.logger.error
    error_log.assert_called()
    assert "other.com" in str(error_log.call_args)

    writer.close.assert_called()

    # Should close the writer
    mock_writer.close.assert_called()


@pytest.mark.asyncio
async def test_forward_data_with_interception_response_parsing_error(
    proxy_server, mock_deps
):
    """Interception when the interceptor fails while processing the response.

    ``process_response`` raises ``ValueError``. The proxy must log an error
    whose call mentions "interception" or "response" (case-insensitive).
    Only the logging is asserted here.
    """
    interceptor = mock_deps["interceptor"]
    interceptor.process_request = AsyncMock(return_value=b"processed")
    interceptor.process_response = AsyncMock(
        side_effect=ValueError("Failed to parse response")
    )

    client_reader = AsyncMock()
    client_reader.read.side_effect = [
        b"POST /GenerateContent HTTP/1.1\r\nHost: example.com\r\n\r\nBody",
        b"",
    ]
    server_reader = AsyncMock()
    server_reader.read.side_effect = [
        b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nResponse body",
        b"",
    ]
    server_writer = MagicMock()
    server_writer.drain = AsyncMock()

    await proxy_server._forward_data_with_interception(
        client_reader, MagicMock(), server_reader, server_writer, "example.com"
    )

    proxy_server.logger.error.assert_called()
    logged = str(proxy_server.logger.error.call_args).lower()
    assert "interception" in logged or "response" in logged