AIstudioProxyAPI / api_utils /client_connection.py
peijun1's picture
Deploy AI Studio Proxy API to Hugging Face Spaces
a5784e9
Raw
History Blame Contribute Delete
7.61 kB
import asyncio
from asyncio import Event, Task
from typing import Any, Callable, Coroutine, Dict, Tuple
from fastapi import HTTPException, Request
from models import ClientDisconnectedError
async def check_client_connection(req_id: str, http_request: Request) -> bool:
    """Check whether the client behind ``http_request`` is still connected.

    Two probes are tried in order:

    1. If the request exposes ``_receive``, it is polled for at most 10 ms;
       an ``http.disconnect`` message means the client is gone.
    2. Otherwise, or when the first probe is inconclusive,
       ``is_disconnected()`` is consulted.  Both plain and awaitable results
       are accepted; an awaitable is bounded by the same 10 ms timeout and a
       timeout is treated as "still connected".

    Args:
        req_id: Request identifier.  Not used by this function.
        http_request: The incoming request.  Only the ``_receive`` and
            ``is_disconnected`` attributes are looked up, so test doubles
            providing either one work as well.

    Returns:
        True if no disconnect was observed, False if one was.

    Raises:
        asyncio.CancelledError: If the calling task is cancelled while
            probing.  Cancellation is never swallowed here, so callers can
            stop their polling loops reliably.
        Exception: Any error raised by ``is_disconnected()`` other than a
            timeout is propagated to the caller.
    """
    if hasattr(http_request, "_receive"):
        receive_task = None
        try:
            # _receive is a private Starlette/FastAPI hook that returns a
            # coroutine; run it as a task so it can be bounded by a timeout.
            # NOTE(review): a non-disconnect message obtained here is consumed
            # and dropped - confirm callers have already read the request body.
            receive_task = asyncio.create_task(
                http_request._receive()  # type: ignore[misc]
            )
            done, _ = await asyncio.wait([receive_task], timeout=0.01)
            if done:
                message = receive_task.result()
                if message.get("type") == "http.disconnect":
                    return False
            else:
                # Nothing arrived in time: stop the probe and fall through to
                # the is_disconnected() check.  asyncio.wait() never re-raises
                # the child's CancelledError, so a CancelledError surfacing
                # here can only be a cancellation of *this* task.
                receive_task.cancel()
                await asyncio.wait([receive_task])
                if not receive_task.cancelled():
                    # Mark a late failure as retrieved to avoid the
                    # "Task exception was never retrieved" warning.
                    receive_task.exception()
        except asyncio.CancelledError:
            raise
        except Exception:
            # Best effort: any probing failure falls back to is_disconnected().
            pass
        finally:
            # If we are being cancelled (or failed) while the probe is still
            # pending, do not leave it running in the background.
            if receive_task is not None and not receive_task.done():
                receive_task.cancel()

    # Fallback to is_disconnected() if available (Starlette/FastAPI).
    if hasattr(http_request, "is_disconnected"):
        try:
            # Handle both sync and async versions for better mock compatibility.
            res = http_request.is_disconnected()
            if asyncio.iscoroutine(res):
                # wait_for prevents an indefinite hang in ASGI implementations
                # whose is_disconnected() blocks until a message arrives.
                if await asyncio.wait_for(res, timeout=0.01):
                    return False
            elif res:
                return False
        except asyncio.TimeoutError:
            # No answer within the timeout: assume the client is still there.
            # CancelledError is deliberately NOT caught here so that cancelling
            # the calling task actually stops it.
            return True
    return True
async def enhanced_disconnect_monitor(
    req_id: str,
    http_request: Request,
    completion_event: asyncio.Event,
    logger: Any,
) -> bool:
    """Poll the client connection while a streaming response is in flight.

    The loop runs until ``completion_event`` is set.  A disconnect is only
    reported after three consecutive failed connection checks; one successful
    check resets the streak.  When the disconnect is confirmed,
    ``completion_event`` is set before returning.

    Returns:
        True when a disconnect was confirmed.  False when the event was set
        elsewhere, when this coroutine was cancelled, or when a check raised
        (the error is logged through ``logger``).
    """
    confirmations_needed = 3
    poll_interval = 0.2
    missed_checks = 0

    while not completion_event.is_set():
        try:
            if await check_client_connection(req_id, http_request):
                # Any successful check wipes out earlier suspicious readings.
                missed_checks = 0
            else:
                missed_checks += 1
                if missed_checks >= confirmations_needed:
                    logger.info(
                        f"[{req_id}] Client disconnect confirmed during streaming."
                    )
                    completion_event.set()
                    return True
            await asyncio.sleep(poll_interval)
        except asyncio.CancelledError:
            # Cancellation ends monitoring quietly.
            return False
        except Exception as e:
            logger.error(f"[{req_id}] Error in enhanced_disconnect_monitor: {e}")
            return False
    return False
async def non_streaming_disconnect_monitor(
    req_id: str,
    http_request: Request,
    result_future: asyncio.Future,
    logger: Any,
) -> bool:
    """Poll the client connection while a non-streaming result is pending.

    The loop runs until ``result_future`` is done.  The first failed
    connection check is treated as a disconnect: it is logged and, if the
    future is still pending, the future is failed with an ``HTTPException``
    carrying status code 499.

    Returns:
        True when a disconnect was detected.  False when the future completed
        first, when this coroutine was cancelled, or when a check raised (the
        error is logged through ``logger``).
    """
    poll_interval = 0.3

    while True:
        if result_future.done():
            return False
        try:
            still_connected = await check_client_connection(req_id, http_request)
            if still_connected:
                await asyncio.sleep(poll_interval)
                continue

            logger.info(
                f"[{req_id}] Client disconnect detected during non-streaming."
            )
            # The future may have completed while the check was awaiting.
            if not result_future.done():
                result_future.set_exception(
                    HTTPException(status_code=499, detail="Client disconnected")
                )
            return True
        except asyncio.CancelledError:
            # Cancellation ends monitoring quietly.
            return False
        except Exception as e:
            logger.error(f"[{req_id}] Error in non_streaming_disconnect_monitor: {e}")
            return False
async def setup_disconnect_monitoring(
    req_id: str, http_request: Request, result_future
) -> Tuple[Event, asyncio.Task, Callable]:
    """Start background disconnect detection for one request.

    A task is spawned that checks the client connection and then sleeps
    0.3 s, repeating until the disconnect event is set.  After five
    consecutive failed checks the event is set and, if ``result_future`` is
    still pending, it is failed with an ``HTTPException`` (status 499).  An
    unexpected error inside the checker also sets the event and fails the
    future, with status 500.

    Returns:
        A 3-tuple of:
        * the ``Event`` that is set once a disconnect (or checker error) occurs,
        * the background ``asyncio.Task`` running the checker,
        * a callable ``check_client_disconnected(stage="")`` that raises
          ``ClientDisconnectedError`` if the event is set and returns False
          otherwise.
    """
    # Imported at call time rather than at module import time.
    from api_utils.server_state import state

    logger = state.logger
    client_disconnected_event = Event()
    # Require 5 consecutive disconnect signals (roughly 1.5 seconds of polling).
    disconnect_threshold = 5

    async def check_disconnect_periodically():
        consecutive_hits = 0
        while not client_disconnected_event.is_set():
            try:
                if await check_client_connection(req_id, http_request):
                    # A successful check resets the streak.
                    consecutive_hits = 0
                else:
                    consecutive_hits += 1
                    if consecutive_hits >= disconnect_threshold:
                        logger.info(
                            f"[{req_id}] Active detection of client disconnect (consecutive {consecutive_hits} times)."
                        )
                        client_disconnected_event.set()
                        if not result_future.done():
                            result_future.set_exception(
                                HTTPException(
                                    status_code=499,
                                    detail=f"[{req_id}] Client closed the request",
                                )
                            )
                        return
                    logger.debug(
                        f"[{req_id}] Active detection of potential disconnect (round {consecutive_hits}/{disconnect_threshold})"
                    )
                await asyncio.sleep(0.3)
            except asyncio.CancelledError:
                # Task cancelled: stop polling without touching the future.
                return
            except Exception as e:
                logger.error(f"(Disco Check Task) Error: {e}")
                client_disconnected_event.set()
                if not result_future.done():
                    result_future.set_exception(
                        HTTPException(
                            status_code=500,
                            detail=f"[{req_id}] Internal disconnect checker error: {e}",
                        )
                    )
                return

    disconnect_check_task = asyncio.create_task(check_disconnect_periodically())

    def check_client_disconnected(stage: str = "") -> bool:
        if not client_disconnected_event.is_set():
            return False
        logger.info(f"Client disconnected detected at stage: '{stage}'")
        raise ClientDisconnectedError(
            f"[{req_id}] Client disconnected at stage: {stage}"
        )

    return client_disconnected_event, disconnect_check_task, check_client_disconnected