Spaces:
Paused
Paused
| import asyncio | |
| from asyncio import Event, Task | |
| from typing import Any, Callable, Coroutine, Dict, Tuple | |
| from fastapi import HTTPException, Request | |
| from models import ClientDisconnectedError | |
async def check_client_connection(req_id: str, http_request: Request) -> bool:
    """
    Checks if the client is still connected.

    Returns True if connected, False if disconnected.

    Two probes are tried in order:

    1. If the request has the private Starlette ``_receive`` callable, one
       ASGI message is awaited for at most 10 ms. An ``http.disconnect``
       message means the client is gone. Any other message received here
       (e.g. a body chunk) is consumed and not replayed. Failures in this
       probe are ignored and the fallback below is used instead.
    2. If the request has ``is_disconnected()``, it is called. A plain
       (synchronous) result is used directly. An awaitable result (a
       coroutine, Future, or any other awaitable) is awaited with a 10 ms
       timeout, and a timeout is treated as "still connected".

    If neither attribute exists, the client is assumed to be connected.

    Args:
        req_id: Request identifier. Not used by the probes themselves; kept
            for signature symmetry with the monitors that call this.
        http_request: The incoming request object.

    Raises:
        asyncio.CancelledError: always propagated, so a monitor task that
            calls this can actually be cancelled.
        Exception: any error raised by ``is_disconnected()`` other than the
            probe timeout propagates to the caller.
    """
    import inspect

    if hasattr(http_request, "_receive"):
        try:
            # _receive is a private Starlette/FastAPI callable returning a coroutine.
            receive_task: Task[Dict[str, Any]] = asyncio.create_task(
                http_request._receive()  # type: ignore[attr-defined]
            )
            try:
                done, _ = await asyncio.wait({receive_task}, timeout=0.01)
            finally:
                # Timed out, or this task is being cancelled: never leave the
                # probe running in the background.
                if not receive_task.done():
                    receive_task.cancel()
            if done:
                message = receive_task.result()
                if message.get("type") == "http.disconnect":
                    return False
            else:
                # Let the cancelled probe unwind. Unlike `await receive_task`,
                # asyncio.wait() does not re-raise the probe's CancelledError,
                # so a cancellation of *this* task still propagates.
                await asyncio.wait({receive_task})
        except asyncio.CancelledError:
            raise
        except Exception:
            # Best effort: the private API may be absent or behave differently
            # across ASGI servers; fall through to the public check.
            pass

    # Fallback to is_disconnected() if available (Starlette/FastAPI).
    # Awaitables are bounded by wait_for to avoid hanging on some ASGI servers.
    if hasattr(http_request, "is_disconnected"):
        try:
            # Sync results are accepted too, for compatibility with simple mocks.
            disconnected = http_request.is_disconnected()
            if inspect.isawaitable(disconnected):
                disconnected = await asyncio.wait_for(disconnected, timeout=0.01)
        except asyncio.TimeoutError:
            # No answer within the window: most likely still connected.
            return True
        if disconnected:
            return False
    return True
async def enhanced_disconnect_monitor(
    req_id: str,
    http_request: Request,
    completion_event: asyncio.Event,
    logger: Any,
) -> bool:
    """
    Watch for a client disconnect while a streaming response is in progress.

    The connection is polled every 0.2 s until ``completion_event`` is set.
    Three "disconnected" results in a row (any "connected" result resets the
    streak) confirm the disconnect. The coroutine then logs it, sets
    ``completion_event``, and returns True.

    Returns False when ``completion_event`` gets set elsewhere first, when
    this coroutine is cancelled (the CancelledError is not re-raised), or
    when the connection check raises (the error is logged).
    """
    misses_in_a_row = 0
    while not completion_event.is_set():
        try:
            if await check_client_connection(req_id, http_request):
                misses_in_a_row = 0
            else:
                misses_in_a_row += 1
                if misses_in_a_row >= 3:
                    logger.info(
                        f"[{req_id}] Client disconnect confirmed during streaming."
                    )
                    completion_event.set()
                    return True
            await asyncio.sleep(0.2)
        except asyncio.CancelledError:
            return False
        except Exception as e:
            logger.error(f"[{req_id}] Error in enhanced_disconnect_monitor: {e}")
            return False
    return False
async def non_streaming_disconnect_monitor(
    req_id: str,
    http_request: Request,
    result_future: asyncio.Future,
    logger: Any,
) -> bool:
    """
    Watch for a client disconnect while a non-streaming result is pending.

    The connection is polled every 0.3 s until ``result_future`` is done. The
    first "disconnected" result is acted on immediately, without debouncing:
    it is logged, the still-pending future is failed with
    ``HTTPException(status_code=499)``, and True is returned.

    Returns False when the future completes first, when this coroutine is
    cancelled (the CancelledError is not re-raised), or when the connection
    check raises (the error is logged).
    """
    while not result_future.done():
        try:
            still_connected = await check_client_connection(req_id, http_request)
            if still_connected:
                await asyncio.sleep(0.3)
                continue
            logger.info(
                f"[{req_id}] Client disconnect detected during non-streaming."
            )
            # Re-check: the result may have landed while we were probing.
            if not result_future.done():
                result_future.set_exception(
                    HTTPException(status_code=499, detail="Client disconnected")
                )
            return True
        except asyncio.CancelledError:
            return False
        except Exception as e:
            logger.error(f"[{req_id}] Error in non_streaming_disconnect_monitor: {e}")
            return False
    return False
async def setup_disconnect_monitoring(
    req_id: str, http_request: Request, result_future: asyncio.Future
) -> Tuple[Event, asyncio.Task, Callable[[str], bool]]:
    """
    Start a background task that watches for the client disconnecting.

    The task polls ``check_client_connection`` and only acts after
    ``disconnect_threshold`` consecutive "disconnected" results. Any
    "connected" result resets the streak. When the threshold is reached, it
    sets the returned event and, if ``result_future`` is still pending, fails
    it with ``HTTPException(499)``. If a poll raises, the event is also set
    and a pending ``result_future`` is failed with ``HTTPException(500)``.

    The task stops on its own only when the event is set or the task is
    cancelled (the CancelledError is swallowed, not re-raised). It does not
    stop when ``result_future`` completes, so the caller is responsible for
    cancelling it.

    Args:
        req_id: Request identifier used in log lines and error details.
        http_request: The incoming request to monitor.
        result_future: Future that is failed when a disconnect (or a checker
            error) is detected, if it is not already done.

    Returns:
        ``(client_disconnected_event, disconnect_check_task,
        check_client_disconnected)``. ``check_client_disconnected(stage)``
        returns False while the event is clear and raises
        ``ClientDisconnectedError`` once it is set. Because checker errors
        also set the event, the raise covers that case too.
    """
    from api_utils.server_state import state

    logger = state.logger
    client_disconnected_event = Event()
    disconnect_count = 0
    # Debounce: 5 consecutive misses with a 0.3 s pause between polls is about
    # 1.2 s (4 pauses) plus probe time of sustained disconnect before acting.
    disconnect_threshold = 5
    poll_interval = 0.3

    async def check_disconnect_periodically():
        nonlocal disconnect_count
        while not client_disconnected_event.is_set():
            try:
                is_connected = await check_client_connection(req_id, http_request)
                if is_connected:
                    disconnect_count = 0  # Reset streak on a successful probe
                else:
                    disconnect_count += 1
                    if disconnect_count >= disconnect_threshold:
                        logger.info(
                            f"[{req_id}] Active detection of client disconnect (consecutive {disconnect_count} times)."
                        )
                        client_disconnected_event.set()
                        if not result_future.done():
                            result_future.set_exception(
                                HTTPException(
                                    status_code=499,
                                    detail=f"[{req_id}] Client closed the request",
                                )
                            )
                        break
                    logger.debug(
                        f"[{req_id}] Active detection of potential disconnect (round {disconnect_count}/{disconnect_threshold})"
                    )
                await asyncio.sleep(poll_interval)
            except asyncio.CancelledError:
                # Task cancelled by its owner: exit quietly.
                break
            except Exception as e:
                logger.error(f"[{req_id}] (Disco Check Task) Error: {e}")
                # Treat a broken checker like a disconnect so processing stops.
                client_disconnected_event.set()
                if not result_future.done():
                    result_future.set_exception(
                        HTTPException(
                            status_code=500,
                            detail=f"[{req_id}] Internal disconnect checker error: {e}",
                        )
                    )
                break

    disconnect_check_task = asyncio.create_task(check_disconnect_periodically())

    def check_client_disconnected(stage: str = "") -> bool:
        """Raise ClientDisconnectedError if the event is set; else return False."""
        if client_disconnected_event.is_set():
            logger.info(f"[{req_id}] Client disconnected detected at stage: '{stage}'")
            raise ClientDisconnectedError(
                f"[{req_id}] Client disconnected at stage: {stage}"
            )
        return False

    return client_disconnected_event, disconnect_check_task, check_client_disconnected