File size: 17,530 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
"""
Integration tests for streaming response generation with real async generators.

These tests verify actual async generator behavior with real asyncio primitives,
ensuring streaming works correctly end-to-end without over-mocking.

Test Strategy:
- Use REAL async generators (not mocked iterators)
- Use REAL asyncio.Event for completion signaling
- Test actual SSE chunk format and ordering
- Verify real async behavior and backpressure
- Mock only data sources (stream responses)

Coverage Target: Stream generator integrity and async behavior
"""

import asyncio
import json
from unittest.mock import MagicMock, patch

import pytest

from api_utils.response_generators import gen_sse_from_aux_stream
from models import ClientDisconnectedError


@pytest.mark.integration
class TestStreamingGeneratorBehavior:
    """Integration tests for real async generator behavior.

    Each test drives ``gen_sse_from_aux_stream`` with a real async generator
    patched in for ``use_stream_response`` and a real ``asyncio.Event`` as the
    completion signal. Only the upstream data and the usage statistics are
    faked.
    """

    async def test_generator_yields_actual_async_iterations(self, make_chat_request):
        """
        Test that generator actually yields asynchronously.

        Verifies real async iteration behavior, not just mock iteration: the
        fake upstream generator and the consumer both append to a shared log
        so the relative order of production and consumption can be checked.
        """
        req_id = "int-test-1"
        request = make_chat_request(stream=True)
        completion_event = asyncio.Event()
        check_disconnect = MagicMock()

        # Bodies are cumulative: each item carries the full text so far.
        stream_data = [
            {"body": "First", "reason": "", "done": False},
            {"body": "First Second", "reason": "", "done": False},
            {"body": "First Second Third", "reason": "", "done": True},
        ]

        iteration_log = []

        async def mock_stream_gen(
            rid,
            timeout=5.0,
            page=None,
            check_client_disconnected=None,
            enable_silence_detection=True,
            **kwargs,
        ):
            for idx, item in enumerate(stream_data):
                iteration_log.append(f"yield_{idx}")
                await asyncio.sleep(0.01)  # Simulate async delay
                yield item

        with (
            patch(
                "api_utils.response_generators.use_stream_response",
                side_effect=mock_stream_gen,
            ),
            patch(
                "api_utils.response_generators.calculate_usage_stats",
                return_value={"total_tokens": 5},
            ),
        ):
            chunks = []
            async for chunk in gen_sse_from_aux_stream(
                req_id,
                request,
                "gemini-1.5-pro",
                check_disconnect,
                completion_event,
                5.0,
            ):
                chunks.append(chunk)
                iteration_log.append(f"received_{len(chunks) - 1}")

        # Verify async iterations actually happened
        assert "yield_0" in iteration_log
        assert "received_0" in iteration_log
        assert "yield_1" in iteration_log

        # The first chunk can only be received after the upstream generator
        # started producing its first item.
        yield_0_idx = iteration_log.index("yield_0")
        recv_0_idx = iteration_log.index("received_0")
        assert recv_0_idx > yield_0_idx

    async def test_concurrent_stream_consumption(self, make_chat_request):
        """
        Test multiple concurrent consumers of different streams.

        Verifies generators are independent and don't interfere.

        ``use_stream_response`` is patched exactly once for the whole test and
        the replacement routes each call to the stream registered for the
        request id it receives. Opening a separate ``patch`` inside every
        concurrent task would mutate the same module attribute from several
        tasks; those patches are entered and exited in scheduling order, not
        LIFO order, so a task could pick up the other task's stream or leave a
        mock installed after the test.
        """
        completion_event1 = asyncio.Event()
        completion_event2 = asyncio.Event()
        check_disconnect = MagicMock()

        stream_data1 = [
            {"body": "Stream1 Data", "reason": "", "done": True},
        ]

        stream_data2 = [
            {"body": "Stream2 Data", "reason": "", "done": True},
        ]

        def make_stream(items, delay):
            """Build a fake upstream generator yielding *items* after *delay*."""

            async def stream(
                rid,
                timeout=5.0,
                page=None,
                check_client_disconnected=None,
                enable_silence_detection=True,
                **kwargs,
            ):
                for item in items:
                    await asyncio.sleep(delay)
                    yield item

            return stream

        # Different delays make the two streams finish at different times.
        streams_by_request = {
            "req1": make_stream(stream_data1, 0.02),
            "req2": make_stream(stream_data2, 0.01),
        }

        def route_stream(rid, *args, **kwargs):
            """Dispatch to the fake stream registered for this request id."""
            # NOTE(review): assumes the generator under test forwards its
            # req_id unchanged as the first argument — confirm against
            # api_utils.response_generators.
            return streams_by_request[rid](rid, *args, **kwargs)

        async def consume_stream(req_id, event):
            """Consume a single stream and return every SSE chunk it yields."""
            chunks = []
            async for chunk in gen_sse_from_aux_stream(
                req_id,
                make_chat_request(stream=True),
                "model",
                check_disconnect,
                event,
                5.0,
            ):
                chunks.append(chunk)
            return chunks

        with (
            patch(
                "api_utils.response_generators.use_stream_response",
                side_effect=route_stream,
            ),
            patch(
                "api_utils.response_generators.calculate_usage_stats",
                return_value={"total_tokens": 3},
            ),
        ):
            # Consume both streams concurrently
            task1 = asyncio.create_task(consume_stream("req1", completion_event1))
            task2 = asyncio.create_task(consume_stream("req2", completion_event2))

            chunks1, chunks2 = await asyncio.gather(task1, task2)

        # Both should complete independently
        assert len(chunks1) > 0
        assert len(chunks2) > 0
        assert completion_event1.is_set()
        assert completion_event2.is_set()

        # Verify content separation: each consumer sees its own stream and
        # nothing from the other one.
        content1 = "".join(chunks1)
        content2 = "".join(chunks2)
        assert "Stream1" in content1
        assert "Stream2" in content2
        assert "Stream2" not in content1
        assert "Stream1" not in content2

    async def test_backpressure_handling(self, make_chat_request):
        """
        Test generator handles backpressure (slow consumer).

        Feeds 50 upstream items to a consumer that sleeps after every chunk
        and checks that chunks are still delivered and the completion event
        is set.
        """
        req_id = "int-test-backpressure"
        request = make_chat_request(stream=True)
        completion_event = asyncio.Event()
        check_disconnect = MagicMock()

        # Large stream to test buffering. Note the bodies are NOT cumulative
        # ("Chunk 0", "Chunk 1", ...), so the number of emitted content chunks
        # is not expected to equal the number of input items.
        stream_data = [
            {"body": f"Chunk {i}", "reason": "", "done": i == 49} for i in range(50)
        ]

        async def mock_stream_gen(
            rid,
            timeout=5.0,
            page=None,
            check_client_disconnected=None,
            enable_silence_detection=True,
            **kwargs,
        ):
            for item in stream_data:
                yield item

        with (
            patch(
                "api_utils.response_generators.use_stream_response",
                side_effect=mock_stream_gen,
            ),
            patch(
                "api_utils.response_generators.calculate_usage_stats",
                return_value={"total_tokens": 100},
            ),
        ):
            chunks = []
            async for chunk in gen_sse_from_aux_stream(
                req_id,
                request,
                "gemini-1.5-pro",
                check_disconnect,
                completion_event,
                5.0,
            ):
                chunks.append(chunk)
                # Simulate slow consumer
                await asyncio.sleep(0.001)

        # Should receive chunks despite slow consumption
        # Filter out [DONE] and usage chunks
        content_chunks = [
            c for c in chunks if "[DONE]" not in c and "usage" not in c.lower()
        ]

        # Generator creates delta chunks from full responses,
        # so expect fewer chunks than input items (deltas are calculated)
        # With 50 items, we get ~2-4 delta chunks typically
        assert len(content_chunks) >= 2, (
            f"Expected at least 2 chunks, got {len(content_chunks)}"
        )

        # Verify something was delivered at all
        assert len(chunks) > 0

        # Verify completion event was set
        assert completion_event.is_set()

    async def test_completion_event_timing(self, make_chat_request):
        """
        Test that completion event is set at the right time.

        Records the event state each time a chunk is received: the event must
        still be clear when the first chunk arrives and must be set once the
        generator is exhausted.
        """
        req_id = "int-test-event"
        request = make_chat_request(stream=True)
        completion_event = asyncio.Event()
        check_disconnect = MagicMock()

        stream_data = [
            {"body": "A", "reason": "", "done": False},
            {"body": "AB", "reason": "", "done": False},
            {"body": "ABC", "reason": "", "done": True},
        ]

        event_check_log = []

        async def mock_stream_gen(
            rid,
            timeout=5.0,
            page=None,
            check_client_disconnected=None,
            enable_silence_detection=True,
            **kwargs,
        ):
            for item in stream_data:
                yield item

        with (
            patch(
                "api_utils.response_generators.use_stream_response",
                side_effect=mock_stream_gen,
            ),
            patch(
                "api_utils.response_generators.calculate_usage_stats",
                return_value={"total_tokens": 3},
            ),
        ):
            async for _chunk in gen_sse_from_aux_stream(
                req_id,
                request,
                "gemini-1.5-pro",
                check_disconnect,
                completion_event,
                5.0,
            ):
                # Record event state after each chunk
                event_check_log.append(completion_event.is_set())

        # Fail with a readable message (not an IndexError) if nothing was
        # yielded at all.
        assert event_check_log, "Generator yielded no chunks"

        # Event should not be set initially
        assert event_check_log[0] is False

        # Event should eventually be set
        assert completion_event.is_set()


@pytest.mark.integration
class TestStreamingErrorHandling:
    """Integration tests for error handling in streaming.

    Both tests check that the completion event passed to
    ``gen_sse_from_aux_stream`` ends up set when streaming is cut short,
    either by an upstream exception or by a client disconnect.
    """

    async def test_generator_cleanup_on_exception(self, make_chat_request):
        """
        Test that generator properly cleans up on exception.

        Verifies completion event is set even on error. Whether the generator
        swallows the upstream error or lets it propagate is deliberately not
        asserted here.
        """
        req_id = "int-test-error"
        request = make_chat_request(stream=True)
        completion_event = asyncio.Event()
        check_disconnect = MagicMock()

        async def mock_stream_gen(
            rid,
            timeout=5.0,
            page=None,
            check_client_disconnected=None,
            enable_silence_detection=True,
            **kwargs,
        ):
            # Deliver one item, then fail mid-stream.
            yield {"body": "First chunk", "reason": "", "done": False}
            raise Exception("Stream error")

        with patch(
            "api_utils.response_generators.use_stream_response",
            side_effect=mock_stream_gen,
        ):
            try:
                async for _chunk in gen_sse_from_aux_stream(
                    req_id,
                    request,
                    "gemini-1.5-pro",
                    check_disconnect,
                    completion_event,
                    5.0,
                ):
                    pass
            except Exception:
                # Intentionally broad: propagation of the upstream error is
                # acceptable, only the cleanup below is under test.
                pass

        # Completion event should be set for cleanup
        assert completion_event.is_set()

    async def test_disconnect_during_streaming(self, make_chat_request):
        """
        Test client disconnect detection during active streaming.

        Verifies generator stops cleanly on disconnect: the disconnect check
        succeeds twice and raises ``ClientDisconnectedError`` from the third
        call onward, so the generator must stop well before the ten upstream
        items are exhausted.
        """
        req_id = "int-test-disconnect"
        request = make_chat_request(stream=True)
        completion_event = asyncio.Event()

        checks_seen = 0

        def fail_from_third_check(*args, **kwargs):
            """Pass the first two checks, then report a disconnect forever."""
            nonlocal checks_seen
            checks_seen += 1
            if checks_seen > 2:
                raise ClientDisconnectedError("Disconnected")
            return None

        # A function (not a finite list) is used as side_effect so that any
        # additional check after the disconnect keeps raising
        # ClientDisconnectedError instead of StopIteration from an exhausted
        # iterator.
        check_disconnect = MagicMock(side_effect=fail_from_third_check)

        # Cumulative bodies: every item extends the previous text, so each
        # item carries new content. Without a disconnect all ten items would
        # be streamed, which keeps the "< 10" assertion below meaningful.
        stream_data = [
            {
                "body": " ".join(f"Chunk {n}" for n in range(i + 1)),
                "reason": "",
                "done": False,
            }
            for i in range(10)
        ]

        async def mock_stream_gen(
            rid,
            timeout=5.0,
            page=None,
            check_client_disconnected=None,
            enable_silence_detection=True,
            **kwargs,
        ):
            for item in stream_data:
                yield item

        with patch(
            "api_utils.response_generators.use_stream_response",
            side_effect=mock_stream_gen,
        ):
            chunks = []
            async for chunk in gen_sse_from_aux_stream(
                req_id,
                request,
                "gemini-1.5-pro",
                check_disconnect,
                completion_event,
                5.0,
            ):
                chunks.append(chunk)

        # Should stop early (less than 10 chunks)
        assert len(chunks) < 10

        # Completion event should be set
        assert completion_event.is_set()


@pytest.mark.integration
class TestStreamingDataIntegrity:
    """Integration tests for data integrity in streaming.

    Covers the content of the SSE chunks produced by
    ``gen_sse_from_aux_stream``: the incremental deltas derived from
    cumulative upstream bodies, and the wire format of each event.
    """

    async def test_incremental_content_deltas(self, make_chat_request):
        """
        Test that content deltas are correctly calculated.

        Verifies incremental updates show only new content: upstream bodies
        are cumulative, the emitted deltas must contain just the added text.
        """
        req_id = "int-test-deltas"
        request = make_chat_request(stream=True)
        completion_event = asyncio.Event()
        check_disconnect = MagicMock()

        stream_data = [
            {"body": "Hello", "reason": "", "done": False},
            {"body": "Hello world", "reason": "", "done": False},
            {"body": "Hello world!", "reason": "", "done": True},
        ]

        async def mock_stream_gen(
            rid,
            timeout=5.0,
            page=None,
            check_client_disconnected=None,
            enable_silence_detection=True,
            **kwargs,
        ):
            for item in stream_data:
                yield item

        with (
            patch(
                "api_utils.response_generators.use_stream_response",
                side_effect=mock_stream_gen,
            ),
            patch(
                "api_utils.response_generators.calculate_usage_stats",
                return_value={"total_tokens": 5},
            ),
        ):
            chunks = []
            async for chunk in gen_sse_from_aux_stream(
                req_id,
                request,
                "gemini-1.5-pro",
                check_disconnect,
                completion_event,
                5.0,
            ):
                if "[DONE]" not in chunk:
                    chunks.append(chunk)

        # Parse deltas
        deltas = []
        for chunk in chunks:
            # Strip only the leading SSE field name; str.replace would also
            # remove any "data: " occurring inside the JSON payload itself.
            payload = chunk.strip().removeprefix("data: ")
            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if "choices" in data and data["choices"]:
                delta_content = data["choices"][0].get("delta", {}).get("content")
                if delta_content:
                    deltas.append(delta_content)

        # Each delta carries only the newly added text. Comparing the whole
        # list reports missing or extra deltas as an assertion diff rather
        # than an IndexError.
        assert deltas == ["Hello", " world", "!"]

        # Accumulated deltas should match final content
        assert "".join(deltas) == "Hello world!"

    async def test_sse_format_compliance(self, make_chat_request):
        """
        Test that SSE format is compliant with spec.

        Verifies data: prefix and proper line endings, and that the payload
        of every non-[DONE] event is JSON containing ``choices`` and
        ``model``.
        """
        req_id = "int-test-sse"
        request = make_chat_request(stream=True)
        completion_event = asyncio.Event()
        check_disconnect = MagicMock()

        stream_data = [
            {"body": "Test", "reason": "", "done": True},
        ]

        async def mock_stream_gen(
            rid,
            timeout=5.0,
            page=None,
            check_client_disconnected=None,
            enable_silence_detection=True,
            **kwargs,
        ):
            for item in stream_data:
                yield item

        with (
            patch(
                "api_utils.response_generators.use_stream_response",
                side_effect=mock_stream_gen,
            ),
            patch(
                "api_utils.response_generators.calculate_usage_stats",
                return_value={"total_tokens": 1},
            ),
        ):
            chunks = []
            async for chunk in gen_sse_from_aux_stream(
                req_id,
                request,
                "gemini-1.5-pro",
                check_disconnect,
                completion_event,
                5.0,
            ):
                chunks.append(chunk)

        # Verify SSE format
        for chunk in chunks:
            if "[DONE]" in chunk:
                continue

            # Should start with "data: "
            assert chunk.startswith("data: ")

            # An event is dispatched by a blank line; the event-stream format
            # accepts LF, CRLF or CR as the line terminator.
            assert chunk.endswith(("\n\n", "\r\n\r\n", "\r\r")), (
                f"SSE chunk is not terminated by a blank line: {chunk!r}"
            )

            # Should be valid JSON after "data: " (prefix only, so payload
            # text containing "data: " is left intact).
            json_part = chunk.removeprefix("data: ").strip()
            try:
                parsed = json.loads(json_part)
            except json.JSONDecodeError:
                pytest.fail(f"Invalid JSON in SSE chunk: {chunk}")

            assert "choices" in parsed
            assert "model" in parsed