AIstudioProxyAPI / tests /api_utils /test_client_connection.py
peijun1's picture
Deploy AI Studio Proxy API to Hugging Face Spaces
a5784e9
Raw
History Blame Contribute Delete
11.4 kB
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request
from api_utils.client_connection import (
check_client_connection,
setup_disconnect_monitoring,
)
from models import ClientDisconnectedError
@pytest.mark.asyncio
async def test_check_client_connection_success():
    """A regular (non-disconnect) receive message yields a ``True`` result."""
    fake_request = MagicMock(spec=Request)
    # The receive channel hands back an ordinary request message, and the
    # request object itself also reports that it is still connected.
    fake_request._receive = AsyncMock(return_value={"type": "http.request"})
    fake_request.is_disconnected = AsyncMock(return_value=False)

    still_connected = await check_client_connection("test_req", fake_request)

    assert still_connected is True
@pytest.mark.asyncio
async def test_check_client_connection_disconnected():
    """An ``http.disconnect`` receive message yields a ``False`` result."""
    fake_request = MagicMock(spec=Request)
    # The receive channel immediately reports that the peer went away.
    fake_request._receive = AsyncMock(return_value={"type": "http.disconnect"})

    still_connected = await check_client_connection("test_req", fake_request)

    assert still_connected is False
@pytest.mark.asyncio
async def test_check_client_connection_timeout():
    """A slow receive call combined with a connected request yields ``True``."""

    async def slow_receive():
        # Stall for a full second before answering, so the receive call
        # does not complete promptly.
        await asyncio.sleep(1)
        return {"type": "http.request"}

    fake_request = MagicMock(spec=Request)
    fake_request.is_disconnected = AsyncMock(return_value=False)
    fake_request._receive = slow_receive

    # On timeout the client is assumed to still be connected.
    still_connected = await check_client_connection("test_req", fake_request)

    assert still_connected is True
@pytest.mark.asyncio
async def test_check_client_connection_exception():
    """A receive call that raises is reported as a disconnected client."""
    fake_request = MagicMock(spec=Request)
    # Awaiting the receive stub raises instead of returning a message.
    fake_request._receive = AsyncMock(side_effect=Exception("Connection error"))

    still_connected = await check_client_connection("test_req", fake_request)

    assert still_connected is False
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_active_disconnect():
    """Test disconnect monitoring when client actively disconnects.

    ``check_client_connection`` is patched to always return ``False``.
    The test then expects the monitor to set the disconnect event, fail
    ``result_future`` with an ``HTTPException`` carrying status 499, and make
    the returned check function raise ``ClientDisconnectedError``.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=True)
    result_future = asyncio.Future()

    # Mock check_client_connection to return False (disconnected)
    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
    ) as mock_check:
        mock_check.return_value = False
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Wait for task to process (threshold is 5 consecutive checks at 0.3s each)
            await asyncio.sleep(2.0)

            assert event.is_set()
            assert result_future.done()
            with pytest.raises(HTTPException) as exc:
                result_future.result()
            # Checked after the raises-block so the assertion really executes.
            assert exc.value.status_code == 499

            # Verify check function raises error
            with pytest.raises(ClientDisconnectedError):
                check_func("test_stage")
        finally:
            # Always stop the background monitor, even when an assertion
            # above fails, so no task is left running for later tests.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_passive_disconnect():
    """Test disconnect monitoring when client passively disconnects (is_disconnected).

    ``request.is_disconnected`` reports ``True`` and the patched
    ``check_client_connection`` returns ``False``, simulating that the check
    detected the disconnect. The test expects the event to be set and
    ``result_future`` to fail with an ``HTTPException`` carrying status 499.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=True)
    result_future = asyncio.Future()

    # Mock check_client_connection to return False, simulating that it
    # detected the disconnect.
    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
    ) as mock_check:
        mock_check.return_value = False
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Wait for task to process (threshold is 5 consecutive checks at 0.3s each)
            await asyncio.sleep(2.0)

            assert event.is_set()
            assert result_future.done()
            with pytest.raises(HTTPException) as exc:
                result_future.result()
            # Checked after the raises-block so the assertion really executes.
            assert exc.value.status_code == 499
        finally:
            # Always stop the background monitor, even when an assertion
            # above fails, so no task is left running for later tests.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_exception():
    """Test disconnect monitoring handles exceptions.

    The patched ``check_client_connection`` raises on every call. The test
    expects the event to be set and ``result_future`` to fail with an
    ``HTTPException`` carrying status 500.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    result_future = asyncio.Future()

    # Mock check_client_connection to raise exception
    with patch(
        "api_utils.client_connection.check_client_connection",
        side_effect=Exception("Monitor error"),
    ):
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Wait for task to process
            await asyncio.sleep(0.1)

            assert event.is_set()
            assert result_future.done()
            with pytest.raises(HTTPException) as exc:
                result_future.result()
            # Checked after the raises-block so the assertion really executes.
            assert exc.value.status_code == 500
        finally:
            # Always stop the background monitor, even when an assertion
            # above fails, so no task is left running for later tests.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
# ============================================================================
# Edge Cases - check_client_connection
# ============================================================================
@pytest.mark.asyncio
async def test_check_client_connection_via_is_disconnected():
    """
    Test scenario: the receive call stalls, but is_disconnected() returns True
    Expected: check_client_connection returns False
    """

    async def slow_receive():
        # Does not deliver a disconnect message right away; sleeps for a
        # second so the check is expected to time out on it first.
        await asyncio.sleep(1)
        return {"type": "http.request"}

    fake_request = MagicMock(spec=Request)
    fake_request._receive = slow_receive
    # The request itself claims the client has gone away.
    fake_request.is_disconnected = AsyncMock(return_value=True)

    still_connected = await check_client_connection("test_req", fake_request)

    # The is_disconnected() answer decides the outcome.
    assert still_connected is False
@pytest.mark.asyncio
async def test_check_client_connection_outer_exception():
    """
    Test scenario: the receive call stalls and is_disconnected() raises
    Expected: the exception propagates out of check_client_connection
    """

    async def slow_receive():
        # Sleeps for a second so the check is expected to time out on it.
        await asyncio.sleep(1)
        return {"type": "http.request"}

    fake_request = MagicMock(spec=Request)
    fake_request._receive = slow_receive
    # Awaiting is_disconnected() blows up instead of answering.
    fake_request.is_disconnected = AsyncMock(
        side_effect=Exception("is_disconnected error")
    )

    # The error must surface to the caller unchanged.
    with pytest.raises(Exception, match="is_disconnected error"):
        await check_client_connection("test_req", fake_request)
# ============================================================================
# Edge Cases - setup_disconnect_monitoring
# ============================================================================
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_client_stays_connected():
    """
    Test scenario: Client stays connected, result_future completed by other task
    Expected: Monitoring task loops normally, executes sleep

    The patched connection check always returns True and, from its third
    call onward, resolves ``result_future`` itself to stand in for the
    request finishing elsewhere. The test expects at least three checks, a
    normally completed future, and an unset disconnect event.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=False)
    result_future = asyncio.Future()

    # Track check calls
    check_count = 0

    async def mock_check_connected(*args, **kwargs):
        nonlocal check_count
        check_count += 1
        # Complete the future on the third check to stop the loop
        if check_count >= 3 and not result_future.done():
            result_future.set_result({"status": "success"})
        return True  # Client stays connected

    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
        side_effect=mock_check_connected,
    ):
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Wait for multiple checks (0.3s sleep each in the monitoring loop)
            await asyncio.sleep(1.2)

            # Verify: Multiple checks performed
            assert check_count >= 3
            # Verify: future completed normally
            assert result_future.done()
            assert result_future.result() == {"status": "success"}
            # Verify: event not set (no disconnect)
            assert not event.is_set()
        finally:
            # Always stop the background monitor, even when an assertion
            # above fails, so no task is left running for later tests.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_task_cancelled():
    """
    Test scenario: the monitoring task is cancelled while the client is connected
    Expected: the task ends up done and the disconnect event stays unset
    """
    pending_result = asyncio.Future()
    fake_request = MagicMock(spec=Request)
    fake_request.is_disconnected = AsyncMock(return_value=False)

    # The connection check always reports "connected", so the monitor
    # reaches its sleep between checks.
    always_connected = patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
        return_value=True,
    )
    with always_connected:
        event, task, _check_func = await setup_disconnect_monitoring(
            "test_req", fake_request, pending_result
        )

        # Let the monitor begin a check cycle, then cancel it.
        await asyncio.sleep(0.1)
        task.cancel()

        # The task may swallow the cancellation or let it propagate;
        # either outcome is acceptable here.
        try:
            await task
        except asyncio.CancelledError:
            pass

        assert task.done()
        # Cancellation is not a disconnect, so the event must remain clear.
        assert not event.is_set()
@pytest.mark.asyncio
async def test_check_client_disconnected_not_disconnected():
    """
    Test scenario: Call check_client_disconnected() but event not set
    Expected: Return False, no exception thrown

    The patched connection check always returns True, so the check function
    returned by ``setup_disconnect_monitoring`` is called while the
    disconnect event is still clear.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=False)
    result_future = asyncio.Future()

    # Mock check to keep client connected
    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
        return_value=True,
    ):
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Wait a bit but don't let it disconnect
            await asyncio.sleep(0.1)

            # Execute: Call check function
            result = check_func("test_stage")

            # Verify: Return False, no exception thrown
            assert result is False
            # Verify: event not set
            assert not event.is_set()
        finally:
            # Always stop the background monitor, even when an assertion
            # above fails, so no task is left running for later tests.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass