Spaces:
Paused
Paused
File size: 7,606 Bytes
a5784e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 | import asyncio
from asyncio import Event, Task
from typing import Any, Callable, Coroutine, Dict, Tuple
from fastapi import HTTPException, Request
from models import ClientDisconnectedError
async def check_client_connection(req_id: str, http_request: Request) -> bool:
    """Report whether the client behind ``http_request`` still appears connected.

    Two best-effort probes are tried in order, each bounded to 10 ms:

    1. If the request exposes the private ``_receive`` callable
       (Starlette/FastAPI), one ASGI message is awaited. An
       ``http.disconnect`` message means the client is gone. Any other
       message obtained by this probe is discarded.
    2. If the request exposes ``is_disconnected()``, its result is consulted.
       Both plain and coroutine-returning versions are supported so simple
       test doubles work as well.

    Args:
        req_id: Identifier of the request. Not used by the probes; kept so
            the signature matches the monitors that call this function.
        http_request: The incoming request object (or a compatible stand-in).

    Returns:
        ``False`` when a disconnect was positively observed, ``True``
        otherwise (including when a probe timed out or was inconclusive).

    Raises:
        asyncio.CancelledError: When the calling task is cancelled while a
            probe is in progress. Cancellation is never swallowed.
        Exception: Errors raised by ``is_disconnected()`` other than a
            timeout propagate to the caller.
    """
    # Probe 1: peek at the next ASGI message through the private receive hook.
    if hasattr(http_request, "_receive"):
        receive_task = None
        try:
            receive_task = asyncio.create_task(
                http_request._receive()  # type: ignore[attr-defined]
            )
            finished, _ = await asyncio.wait([receive_task], timeout=0.01)
            if finished:
                message = receive_task.result()
                if message.get("type") == "http.disconnect":
                    return False
        except asyncio.CancelledError:
            # Explicit for interpreters where CancelledError derives from
            # Exception; cancellation must always reach the caller.
            raise
        except Exception:
            # This probe is best-effort: any failure falls through to the
            # is_disconnected() fallback below.
            pass
        finally:
            # Never leave the receive task running, whether we timed out,
            # failed, or were cancelled while waiting on it.
            if receive_task is not None and not receive_task.done():
                receive_task.cancel()
                await asyncio.wait([receive_task])
                if not receive_task.cancelled():
                    # Mark a late exception as retrieved so asyncio does not
                    # report it as "never retrieved".
                    receive_task.exception()

    # Probe 2: is_disconnected(), wrapped in wait_for so an implementation
    # that never returns cannot hang the caller.
    if hasattr(http_request, "is_disconnected"):
        try:
            outcome = http_request.is_disconnected()
            if asyncio.iscoroutine(outcome):
                if await asyncio.wait_for(outcome, timeout=0.01):
                    return False
            elif outcome:
                return False
        except asyncio.TimeoutError:
            # No answer within the time limit: treat as still connected.
            return True
        except asyncio.CancelledError:
            # Only treat this as "still connected" when we can prove the
            # current task itself is not being cancelled (Task.cancelling()
            # exists on Python 3.11+). Otherwise honour the cancellation.
            current = asyncio.current_task()
            cancelling = getattr(current, "cancelling", None)
            if cancelling is None or cancelling():
                raise
            return True

    return True
async def enhanced_disconnect_monitor(
    req_id: str,
    http_request: Request,
    completion_event: asyncio.Event,
    logger: Any,
    *,
    poll_interval: float = 0.2,
    disconnect_threshold: int = 3,
) -> bool:
    """Poll for a client disconnect while a streaming response is in flight.

    The connection is checked with ``check_client_connection`` once per
    ``poll_interval`` until ``completion_event`` is set. A disconnect is only
    confirmed after ``disconnect_threshold`` consecutive negative checks; a
    single positive check resets the count.

    Args:
        req_id: Request identifier used as a prefix in log messages.
        http_request: The request whose connection is being watched.
        completion_event: Set by the caller when streaming finishes. This
            function also sets it when a disconnect is confirmed.
        logger: Object providing ``info`` and ``error`` methods.
        poll_interval: Seconds to sleep between checks (default 0.2).
        disconnect_threshold: Consecutive disconnect signals required before
            the disconnect is confirmed (default 3). Values below 1 behave
            like 1.

    Returns:
        ``True`` if a disconnect was confirmed. ``False`` if the loop ended
        for any other reason: completion, cancellation, or an error (which is
        logged).
    """
    consecutive_disconnects = 0
    while not completion_event.is_set():
        try:
            is_connected = await check_client_connection(req_id, http_request)
            if not is_connected:
                consecutive_disconnects += 1
                if consecutive_disconnects >= disconnect_threshold:
                    logger.info(
                        f"[{req_id}] Client disconnect confirmed during streaming."
                    )
                    completion_event.set()
                    return True
            else:
                # A successful check means earlier signals were spurious.
                consecutive_disconnects = 0
            await asyncio.sleep(poll_interval)
        except asyncio.CancelledError:
            # Cancellation ends monitoring quietly; the result is False.
            break
        except Exception as e:
            logger.error(f"[{req_id}] Error in enhanced_disconnect_monitor: {e}")
            break
    return False
async def non_streaming_disconnect_monitor(
    req_id: str,
    http_request: Request,
    result_future: asyncio.Future,
    logger: Any,
    *,
    poll_interval: float = 0.3,
) -> bool:
    """Poll for a client disconnect while a non-streaming result is pending.

    The connection is checked with ``check_client_connection`` once per
    ``poll_interval`` until ``result_future`` is done. The first negative
    check is treated as a disconnect: the future (if still pending) is failed
    with an ``HTTPException`` carrying status code 499.

    Args:
        req_id: Request identifier used as a prefix in log messages.
        http_request: The request whose connection is being watched.
        result_future: Future that the request processing will resolve.
        logger: Object providing ``info`` and ``error`` methods.
        poll_interval: Seconds to sleep between checks (default 0.3).

    Returns:
        ``True`` if a disconnect was detected. ``False`` if the loop ended
        for any other reason: the future completed, the monitor was
        cancelled, or an error occurred (which is logged).
    """
    while not result_future.done():
        try:
            is_connected = await check_client_connection(req_id, http_request)
            if not is_connected:
                logger.info(
                    f"[{req_id}] Client disconnect detected during non-streaming."
                )
                # The future may have completed while the check was awaited.
                if not result_future.done():
                    result_future.set_exception(
                        HTTPException(status_code=499, detail="Client disconnected")
                    )
                return True
            await asyncio.sleep(poll_interval)
        except asyncio.CancelledError:
            # Cancellation ends monitoring quietly; the result is False.
            break
        except Exception as e:
            logger.error(f"[{req_id}] Error in non_streaming_disconnect_monitor: {e}")
            break
    return False
async def setup_disconnect_monitoring(
    req_id: str,
    http_request: Request,
    result_future,
    *,
    disconnect_threshold: int = 5,
    poll_interval: float = 0.3,
) -> Tuple[Event, asyncio.Task, Callable]:
    """Start a background task that watches for the client disconnecting.

    The task checks the connection with ``check_client_connection`` once per
    ``poll_interval``. After ``disconnect_threshold`` consecutive negative
    checks it sets the returned event and, if ``result_future`` is still
    pending, fails it with an ``HTTPException`` (status code 499). If the
    check itself raises, the event is also set and the future is failed with
    status code 500.

    Args:
        req_id: Request identifier used in log and error messages.
        http_request: The request whose connection is being watched.
        result_future: Future that the request processing will resolve.
        disconnect_threshold: Consecutive disconnect signals required before
            the disconnect is acted upon (default 5). Values below 1 behave
            like 1.
        poll_interval: Seconds to sleep between checks (default 0.3).

    Returns:
        A 3-tuple ``(event, task, checker)``:

        * ``event``: set once a disconnect (or checker failure) is recorded.
        * ``task``: the background polling task; the caller is responsible
          for cancelling it when the request is finished.
        * ``checker``: ``checker(stage="")`` raises
          ``ClientDisconnectedError`` if the event is set and returns
          ``False`` otherwise.
    """
    # NOTE(review): imported at call time rather than module level,
    # presumably to avoid an import cycle. Confirm before moving it.
    from api_utils.server_state import state

    logger = state.logger
    client_disconnected_event = Event()

    async def check_disconnect_periodically():
        # Number of consecutive checks that reported a disconnect.
        disconnect_count = 0
        while not client_disconnected_event.is_set():
            try:
                is_connected = await check_client_connection(req_id, http_request)
                if not is_connected:
                    disconnect_count += 1
                    if disconnect_count >= disconnect_threshold:
                        logger.info(
                            f"[{req_id}] Active detection of client disconnect (consecutive {disconnect_count} times)."
                        )
                        client_disconnected_event.set()
                        if not result_future.done():
                            result_future.set_exception(
                                HTTPException(
                                    status_code=499,
                                    detail=f"[{req_id}] Client closed the request",
                                )
                            )
                        break
                    else:
                        logger.debug(
                            f"[{req_id}] Active detection of potential disconnect (round {disconnect_count}/{disconnect_threshold})"
                        )
                else:
                    # A successful check resets the consecutive counter.
                    disconnect_count = 0
                await asyncio.sleep(poll_interval)
            except asyncio.CancelledError:
                # Task cancelled by its owner: exit without touching state.
                break
            except Exception as e:
                logger.error(f"(Disco Check Task) Error: {e}")
                # Fail closed: flag the request and surface the checker error.
                client_disconnected_event.set()
                if not result_future.done():
                    result_future.set_exception(
                        HTTPException(
                            status_code=500,
                            detail=f"[{req_id}] Internal disconnect checker error: {e}",
                        )
                    )
                break

    disconnect_check_task = asyncio.create_task(check_disconnect_periodically())

    def check_client_disconnected(stage: str = "") -> bool:
        """Raise ``ClientDisconnectedError`` if a disconnect was recorded."""
        if client_disconnected_event.is_set():
            logger.info(f"Client disconnected detected at stage: '{stage}'")
            raise ClientDisconnectedError(
                f"[{req_id}] Client disconnected at stage: {stage}"
            )
        return False

    return client_disconnected_event, disconnect_check_task, check_client_disconnected
|