Spaces:
Paused
Paused
| import asyncio | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from fastapi import HTTPException, Request | |
| from api_utils.client_connection import ( | |
| check_client_connection, | |
| setup_disconnect_monitoring, | |
| ) | |
| from models import ClientDisconnectedError | |
| async def test_check_client_connection_success(): | |
| """Test successful client connection check.""" | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| # Mock _receive to return a non-disconnect message | |
| async def mock_receive(): | |
| return {"type": "http.request"} | |
| request._receive = mock_receive | |
| request.is_disconnected = AsyncMock(return_value=False) | |
| result = await check_client_connection(req_id, request) | |
| assert result is True | |
| async def test_check_client_connection_disconnected(): | |
| """Test client connection check when disconnected.""" | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| # Mock _receive to return a disconnect message | |
| async def mock_receive(): | |
| return {"type": "http.disconnect"} | |
| request._receive = mock_receive | |
| result = await check_client_connection(req_id, request) | |
| assert result is False | |
| async def test_check_client_connection_timeout(): | |
| """Test client connection check timeout.""" | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| # Mock _receive to hang | |
| async def mock_receive(): | |
| await asyncio.sleep(1) | |
| return {"type": "http.request"} | |
| request._receive = mock_receive | |
| request.is_disconnected = AsyncMock(return_value=False) | |
| # Should return True on timeout (assuming connected) | |
| result = await check_client_connection(req_id, request) | |
| assert result is True | |
| async def test_check_client_connection_exception(): | |
| """Test client connection check exception.""" | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| # Mock _receive to raise exception | |
| async def mock_receive(): | |
| raise Exception("Connection error") | |
| request._receive = mock_receive | |
| result = await check_client_connection(req_id, request) | |
| assert result is False | |
| async def test_setup_disconnect_monitoring_active_disconnect(): | |
| """Test disconnect monitoring when client actively disconnects.""" | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| request.is_disconnected = AsyncMock(return_value=True) | |
| result_future = asyncio.Future() | |
| # Mock check_client_connection to return False (disconnected) | |
| with patch( | |
| "api_utils.client_connection.check_client_connection", new_callable=AsyncMock | |
| ) as mock_test: | |
| mock_test.return_value = False | |
| event, task, check_func = await setup_disconnect_monitoring( | |
| req_id, request, result_future | |
| ) | |
| # Wait for task to process (threshold is 5 consecutive checks at 0.3s each) | |
| await asyncio.sleep(2.0) | |
| assert event.is_set() | |
| assert result_future.done() | |
| with pytest.raises(HTTPException) as exc: | |
| result_future.result() | |
| assert exc.value.status_code == 499 | |
| # Verify check function raises error | |
| with pytest.raises(ClientDisconnectedError): | |
| check_func("test_stage") | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| async def test_setup_disconnect_monitoring_passive_disconnect(): | |
| """Test disconnect monitoring when client passively disconnects (is_disconnected).""" | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| request.is_disconnected = AsyncMock(return_value=True) | |
| result_future = asyncio.Future() | |
| # Mock check_client_connection to return False, simulating that it detected the disconnect. | |
| with patch( | |
| "api_utils.client_connection.check_client_connection", new_callable=AsyncMock | |
| ) as mock_test: | |
| mock_test.return_value = False | |
| event, task, check_func = await setup_disconnect_monitoring( | |
| req_id, request, result_future | |
| ) | |
| # Wait for task to process (threshold is 5 consecutive checks at 0.3s each) | |
| await asyncio.sleep(2.0) | |
| assert event.is_set() | |
| assert result_future.done() | |
| with pytest.raises(HTTPException) as exc: | |
| result_future.result() | |
| assert exc.value.status_code == 499 | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| async def test_setup_disconnect_monitoring_exception(): | |
| """Test disconnect monitoring handles exceptions.""" | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| result_future = asyncio.Future() | |
| # Mock check_client_connection to raise exception | |
| with patch( | |
| "api_utils.client_connection.check_client_connection", | |
| side_effect=Exception("Monitor error"), | |
| ): | |
| event, task, check_func = await setup_disconnect_monitoring( | |
| req_id, request, result_future | |
| ) | |
| # Wait for task to process | |
| await asyncio.sleep(0.1) | |
| assert event.is_set() | |
| assert result_future.done() | |
| with pytest.raises(HTTPException) as exc: | |
| result_future.result() | |
| assert exc.value.status_code == 500 | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| # ============================================================================ | |
| # Edge Cases - check_client_connection | |
| # ============================================================================ | |
| async def test_check_client_connection_via_is_disconnected(): | |
| """ | |
| Test scenario: _receive timeout, but is_disconnected() returns True | |
| Expected: Return False (line 47) | |
| """ | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| # _receive does not return disconnect immediately, but times out | |
| async def mock_receive(): | |
| await asyncio.sleep(1) # Will timeout in check | |
| return {"type": "http.request"} | |
| request._receive = mock_receive | |
| # is_disconnected() returns True | |
| request.is_disconnected = AsyncMock(return_value=True) | |
| # Execute | |
| result = await check_client_connection(req_id, request) | |
| # Verify: Return False (line 47 executed) | |
| assert result is False | |
| async def test_check_client_connection_outer_exception(): | |
| """ | |
| Test scenario: is_disconnected() throws exception | |
| Expected: Exception is re-raised (outer exception handler re-raises) | |
| """ | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| # _receive timeout | |
| async def mock_receive(): | |
| await asyncio.sleep(1) | |
| return {"type": "http.request"} | |
| request._receive = mock_receive | |
| # is_disconnected() throws exception | |
| request.is_disconnected = AsyncMock(side_effect=Exception("is_disconnected error")) | |
| # Execute and verify exception is re-raised | |
| with pytest.raises(Exception, match="is_disconnected error"): | |
| await check_client_connection(req_id, request) | |
| # ============================================================================ | |
| # Edge Cases - setup_disconnect_monitoring | |
| # ============================================================================ | |
| async def test_setup_disconnect_monitoring_client_stays_connected(): | |
| """ | |
| Test scenario: Client stays connected, result_future completed by other task | |
| Expected: Monitoring task loops normally, executes sleep | |
| """ | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| request.is_disconnected = AsyncMock(return_value=False) | |
| result_future = asyncio.Future() | |
| # Track check calls | |
| check_count = 0 | |
| async def mock_check_connected(*args, **kwargs): | |
| nonlocal check_count | |
| check_count += 1 | |
| if check_count >= 3: | |
| # Complete the future to stop the loop | |
| if not result_future.done(): | |
| result_future.set_result({"status": "success"}) | |
| return True # Client stays connected | |
| with patch( | |
| "api_utils.client_connection.check_client_connection", | |
| new_callable=AsyncMock, | |
| side_effect=mock_check_connected, | |
| ): | |
| event, task, check_func = await setup_disconnect_monitoring( | |
| req_id, request, result_future | |
| ) | |
| # Wait for multiple checks (0.3s sleep each in the monitoring loop) | |
| await asyncio.sleep(1.2) | |
| # Verify: Multiple checks performed | |
| assert check_count >= 3 | |
| # Verify: future completed normally | |
| assert result_future.done() | |
| assert result_future.result() == {"status": "success"} | |
| # Verify: event not set (no disconnect) | |
| assert not event.is_set() | |
| # Cleanup | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| async def test_setup_disconnect_monitoring_task_cancelled(): | |
| """ | |
| Test scenario: Monitoring task cancelled | |
| Expected: CancelledError caught, task exits gracefully | |
| """ | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| request.is_disconnected = AsyncMock(return_value=False) | |
| result_future = asyncio.Future() | |
| # Mock check to return True (connected), so it enters the sleep | |
| with patch( | |
| "api_utils.client_connection.check_client_connection", | |
| new_callable=AsyncMock, | |
| return_value=True, | |
| ): | |
| event, task, check_func = await setup_disconnect_monitoring( | |
| req_id, request, result_future | |
| ) | |
| # Give it time to start one check cycle | |
| await asyncio.sleep(0.1) | |
| # Execute: Cancel task | |
| task.cancel() | |
| # Verify: Task cancelled | |
| # Task catches CancelledError and exits gracefully, will not re-throw | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| # If it does raise, that's also fine | |
| pass | |
| # Verify: Task done | |
| assert task.done() | |
| # Verify: event not set (task cancelled, not disconnect) | |
| assert not event.is_set() | |
| async def test_check_client_disconnected_not_disconnected(): | |
| """ | |
| Test scenario: Call check_client_disconnected() but event not set | |
| Expected: Return False, no exception thrown | |
| """ | |
| req_id = "test_req" | |
| request = MagicMock(spec=Request) | |
| request.is_disconnected = AsyncMock(return_value=False) | |
| result_future = asyncio.Future() | |
| # Mock check to keep client connected | |
| with patch( | |
| "api_utils.client_connection.check_client_connection", | |
| new_callable=AsyncMock, | |
| return_value=True, | |
| ): | |
| event, task, check_func = await setup_disconnect_monitoring( | |
| req_id, request, result_future | |
| ) | |
| # Wait a bit but don't let it disconnect | |
| await asyncio.sleep(0.1) | |
| # Execute: Call check function | |
| result = check_func("test_stage") | |
| # Verify: Return False, no exception thrown | |
| assert result is False | |
| # Verify: event not set | |
| assert not event.is_set() | |
| # Cleanup | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |