Spaces:
Paused
Paused
File size: 11,407 Bytes
a5784e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 | import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request
from api_utils.client_connection import (
check_client_connection,
setup_disconnect_monitoring,
)
from models import ClientDisconnectedError
@pytest.mark.asyncio
async def test_check_client_connection_success():
    """A plain ``http.request`` message means the client is still connected."""
    fake_request = MagicMock(spec=Request)

    async def receive_request_message():
        # Anything other than ``http.disconnect`` signals a live client.
        return {"type": "http.request"}

    fake_request._receive = receive_request_message
    fake_request.is_disconnected = AsyncMock(return_value=False)

    assert await check_client_connection("test_req", fake_request) is True
@pytest.mark.asyncio
async def test_check_client_connection_disconnected():
    """An ``http.disconnect`` message from ``_receive`` reports the client as gone."""
    fake_request = MagicMock(spec=Request)

    async def receive_disconnect_message():
        return {"type": "http.disconnect"}

    fake_request._receive = receive_disconnect_message

    connected = await check_client_connection("test_req", fake_request)
    assert connected is False
@pytest.mark.asyncio
async def test_check_client_connection_timeout():
    """If ``_receive`` does not answer in time, the client is assumed connected."""
    fake_request = MagicMock(spec=Request)

    async def slow_receive():
        # Stalls for 1s; check_client_connection is expected to stop waiting
        # before then (its internal timeout is assumed to be < 1s — confirm).
        await asyncio.sleep(1)
        return {"type": "http.request"}

    fake_request._receive = slow_receive
    fake_request.is_disconnected = AsyncMock(return_value=False)

    # A timeout combined with is_disconnected() == False counts as "connected".
    assert await check_client_connection("test_req", fake_request) is True
@pytest.mark.asyncio
async def test_check_client_connection_exception():
    """An error raised by ``_receive`` is reported as a lost connection."""
    fake_request = MagicMock(spec=Request)

    async def failing_receive():
        raise Exception("Connection error")

    fake_request._receive = failing_receive

    assert await check_client_connection("test_req", fake_request) is False
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_active_disconnect():
    """Test disconnect monitoring when client actively disconnects.

    ``check_client_connection`` is patched to always report the client as
    gone. The monitor must then set the disconnect event, fail
    ``result_future`` with HTTP 499, and make the returned check function
    raise ``ClientDisconnectedError``.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=True)
    loop = asyncio.get_running_loop()
    result_future = loop.create_future()

    with patch(
        "api_utils.client_connection.check_client_connection", new_callable=AsyncMock
    ) as mock_check:
        mock_check.return_value = False
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Poll instead of sleeping a fixed 2s. The monitor needs several
            # consecutive failed checks (reportedly 5 at ~0.3s each) before
            # it reports a disconnect, so a fixed sleep is slow and has
            # little margin on a loaded machine.
            deadline = loop.time() + 5.0
            while loop.time() < deadline:
                if event.is_set() and result_future.done():
                    break
                await asyncio.sleep(0.05)

            assert event.is_set()
            assert result_future.done()
            with pytest.raises(HTTPException) as exc:
                result_future.result()
            assert exc.value.status_code == 499

            # Once disconnected, the stage check function must raise.
            with pytest.raises(ClientDisconnectedError):
                check_func("test_stage")
        finally:
            # Always stop the monitor, even if an assertion above failed, so
            # it does not outlive the patch and keep running in the loop.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_passive_disconnect():
    """Test disconnect monitoring when client passively disconnects (is_disconnected).

    Unlike the active-disconnect test, the patched ``check_client_connection``
    derives its answer from ``request.is_disconnected()``. The test therefore
    only passes if that signal is actually consulted and propagated into the
    disconnect event and a 499 failure of ``result_future``.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=True)
    loop = asyncio.get_running_loop()
    result_future = loop.create_future()

    async def check_via_is_disconnected(*args, **kwargs):
        # The client counts as connected only while is_disconnected() is False.
        return not await request.is_disconnected()

    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
        side_effect=check_via_is_disconnected,
    ):
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Poll with a generous deadline instead of a fixed 2s sleep; the
            # monitor reportedly needs 5 consecutive failures at ~0.3s each.
            deadline = loop.time() + 5.0
            while loop.time() < deadline:
                if event.is_set() and result_future.done():
                    break
                await asyncio.sleep(0.05)

            assert event.is_set()
            assert result_future.done()
            with pytest.raises(HTTPException) as exc:
                result_future.result()
            assert exc.value.status_code == 499

            # The passive signal must actually have been consulted.
            request.is_disconnected.assert_awaited()
        finally:
            # Stop the monitor even on assertion failure so it cannot leak.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_exception():
    """Test disconnect monitoring handles exceptions.

    An unexpected error from ``check_client_connection`` must set the
    disconnect event and fail ``result_future`` with HTTP 500.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    loop = asyncio.get_running_loop()
    result_future = loop.create_future()

    # patch() substitutes an AsyncMock for a coroutine-function target, so
    # awaiting the patched check raises the side_effect.
    with patch(
        "api_utils.client_connection.check_client_connection",
        side_effect=Exception("Monitor error"),
    ):
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # A fixed 0.1s sleep raced the monitor task; poll with a deadline.
            deadline = loop.time() + 5.0
            while loop.time() < deadline:
                if event.is_set() and result_future.done():
                    break
                await asyncio.sleep(0.01)

            assert event.is_set()
            assert result_future.done()
            with pytest.raises(HTTPException) as exc:
                result_future.result()
            assert exc.value.status_code == 500
        finally:
            # Stop the monitor even on assertion failure so it cannot leak.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
# ============================================================================
# Edge Cases - check_client_connection
# ============================================================================
@pytest.mark.asyncio
async def test_check_client_connection_via_is_disconnected():
    """
    Scenario: ``_receive`` times out, but ``is_disconnected()`` reports True.
    Expected: check_client_connection falls back to is_disconnected() and
    returns False.
    """
    fake_request = MagicMock(spec=Request)

    async def stalled_receive():
        # Does not deliver a disconnect message in time, forcing the
        # is_disconnected() fallback.
        await asyncio.sleep(1)
        return {"type": "http.request"}

    fake_request._receive = stalled_receive
    fake_request.is_disconnected = AsyncMock(return_value=True)

    outcome = await check_client_connection("test_req", fake_request)
    assert outcome is False
@pytest.mark.asyncio
async def test_check_client_connection_outer_exception():
    """
    Scenario: ``_receive`` times out and ``is_disconnected()`` then raises.
    Expected: the error propagates to the caller instead of being swallowed.
    """
    fake_request = MagicMock(spec=Request)

    async def stalled_receive():
        await asyncio.sleep(1)
        return {"type": "http.request"}

    fake_request._receive = stalled_receive
    fake_request.is_disconnected = AsyncMock(
        side_effect=Exception("is_disconnected error")
    )

    with pytest.raises(Exception, match="is_disconnected error"):
        await check_client_connection("test_req", fake_request)
# ============================================================================
# Edge Cases - setup_disconnect_monitoring
# ============================================================================
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_client_stays_connected():
    """
    Test scenario: Client stays connected, result_future completed by other task
    Expected: Monitoring task loops normally (several checks), the future keeps
    the result set by the other task, and the disconnect event is never set.
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=False)
    result_future = asyncio.get_running_loop().create_future()

    check_count = 0

    async def mock_check_connected(*args, **kwargs):
        # Report "connected" every time; on the third check, play the role of
        # the worker task and complete result_future so the monitor can stop.
        nonlocal check_count
        check_count += 1
        if check_count >= 3 and not result_future.done():
            result_future.set_result({"status": "success"})
        return True

    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
        side_effect=mock_check_connected,
    ):
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Allow several monitoring cycles (reportedly ~0.3s apart).
            await asyncio.sleep(1.2)

            assert check_count >= 3
            assert result_future.done()
            assert result_future.result() == {"status": "success"}
            # No disconnect was ever reported.
            assert not event.is_set()
        finally:
            # Stop the monitor even on assertion failure so it cannot leak.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@pytest.mark.asyncio
async def test_setup_disconnect_monitoring_task_cancelled():
    """
    Scenario: the monitoring task is cancelled while the client is connected.
    Expected: the task finishes (whether or not it re-raises CancelledError)
    and the disconnect event stays clear.
    """
    fake_request = MagicMock(spec=Request)
    fake_request.is_disconnected = AsyncMock(return_value=False)
    pending_result = asyncio.Future()

    # Report "connected" so the monitor reaches its wait between checks.
    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
        return_value=True,
    ):
        disconnect_event, monitor_task, _check = await setup_disconnect_monitoring(
            "test_req", fake_request, pending_result
        )
        await asyncio.sleep(0.1)  # let at least one check cycle begin

        monitor_task.cancel()
        # The monitor may swallow CancelledError or re-raise it; both are fine.
        try:
            await monitor_task
        except asyncio.CancelledError:
            pass

        assert monitor_task.done()
        # Cancellation is not a disconnect.
        assert not disconnect_event.is_set()
@pytest.mark.asyncio
async def test_check_client_disconnected_not_disconnected():
    """
    Test scenario: Call the returned check function while the event is not set
    Expected: Return False, no exception thrown
    """
    req_id = "test_req"
    request = MagicMock(spec=Request)
    request.is_disconnected = AsyncMock(return_value=False)
    result_future = asyncio.get_running_loop().create_future()

    # Keep the client "connected" for the whole test.
    with patch(
        "api_utils.client_connection.check_client_connection",
        new_callable=AsyncMock,
        return_value=True,
    ):
        event, task, check_func = await setup_disconnect_monitoring(
            req_id, request, result_future
        )
        try:
            # Let the monitor run briefly without triggering a disconnect.
            await asyncio.sleep(0.1)

            result = check_func("test_stage")

            assert result is False
            assert not event.is_set()
        finally:
            # Stop the monitor even if check_func raised or an assert failed,
            # so the task does not outlive the patch.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
|