File size: 7,606 Bytes
a5784e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import asyncio
from asyncio import Event, Task
from typing import Any, Callable, Coroutine, Dict, Tuple

from fastapi import HTTPException, Request

from models import ClientDisconnectedError


async def check_client_connection(req_id: str, http_request: Request) -> bool:
    """
    Check whether the client behind ``http_request`` is still connected.

    Two best-effort probes run in order:

    1. If the request has a ``_receive`` attribute (Starlette/FastAPI's
       private ASGI receive callable), wait up to 10 ms for the next ASGI
       message; an ``http.disconnect`` message means the client is gone.
       Any non-cancellation failure of this probe is ignored and the next
       probe runs.
    2. If the request has ``is_disconnected`` (sync or async), use its
       result. An async implementation is bounded by a 10 ms timeout, and
       a timeout is treated as "still connected".

    Args:
        req_id: Request identifier. Not used by the checks themselves.
        http_request: The incoming request to inspect.

    Returns:
        ``False`` if a disconnect was observed, otherwise ``True`` (also when
        neither probe is available or both are inconclusive).

    Raises:
        asyncio.CancelledError: If the calling task is cancelled. Cancellation
            is always propagated, never reported as "connected".
        Exception: Any non-timeout error raised by ``is_disconnected()``.
    """
    if hasattr(http_request, "_receive"):
        try:
            # _receive is a private Starlette/FastAPI method returning a coroutine
            # for the next ASGI message.
            receive_coro: Coroutine[Any, Any, Dict[str, Any]] = (
                http_request._receive()  # type: ignore[attr-defined]
            )
            receive_task: Task[Dict[str, Any]] = asyncio.create_task(receive_coro)
            done, _pending = await asyncio.wait([receive_task], timeout=0.01)

            if done:
                # NOTE(review): a message that is not ``http.disconnect`` (e.g. a
                # request-body chunk) is consumed here and dropped.
                message = receive_task.result()
                if message.get("type") == "http.disconnect":
                    return False
            else:
                # Nothing arrived within the window: abandon the probe and fall
                # through to the is_disconnected() check.
                receive_task.cancel()
                # asyncio.wait() never re-raises the task's own outcome, so a
                # cancellation aimed at *this* coroutine still propagates instead
                # of being mistaken for the probe's own CancelledError.
                await asyncio.wait([receive_task])
                if not receive_task.cancelled():
                    # The task finished before the cancel took effect; mark any
                    # error as retrieved so asyncio does not warn about it.
                    receive_task.exception()
        except asyncio.CancelledError:
            raise
        except Exception:
            # Best effort: a failing private-API probe is not fatal.
            pass

    # Fallback to is_disconnected() if available (Starlette/FastAPI).
    if hasattr(http_request, "is_disconnected"):
        try:
            # Accept both async (Starlette) and sync (e.g. test double) variants.
            result = http_request.is_disconnected()
            if asyncio.iscoroutine(result):
                # Bounded so a misbehaving ASGI server cannot hang the caller.
                if await asyncio.wait_for(result, timeout=0.01):
                    return False
            elif result:
                return False
        except asyncio.TimeoutError:
            # No answer within the budget: assume the client is still there.
            # CancelledError is deliberately NOT caught so cancellation propagates.
            return True

    return True


async def enhanced_disconnect_monitor(
    req_id: str,
    http_request: Request,
    completion_event: asyncio.Event,
    logger: Any,
) -> bool:
    """
    Watch for the client going away while a streaming response is produced.

    The connection is probed, with a 0.2 s pause between probes, until
    ``completion_event`` is set. A disconnect is only confirmed after three
    consecutive negative probes; any positive probe resets the streak. On
    confirmation the event is set so the producer side can stop.

    Args:
        req_id: Request identifier used as a log prefix.
        http_request: The request whose connection is probed.
        completion_event: Set by the producer when streaming ends; also set
            here once a disconnect is confirmed.
        logger: Logger receiving the confirmation / error messages.

    Returns:
        True if a disconnect was confirmed. False if streaming completed
        first, the monitor was cancelled, or probing raised (error logged).
    """
    misses = 0
    try:
        while not completion_event.is_set():
            connected = await check_client_connection(req_id, http_request)
            # Consecutive-miss streak: reset on any successful probe.
            misses = 0 if connected else misses + 1
            if misses >= 3:
                logger.info(
                    f"[{req_id}] Client disconnect confirmed during streaming."
                )
                completion_event.set()
                return True
            await asyncio.sleep(0.2)
    except asyncio.CancelledError:
        # Cancellation just ends monitoring; it is reported as "not disconnected".
        pass
    except Exception as exc:
        logger.error(f"[{req_id}] Error in enhanced_disconnect_monitor: {exc}")
    return False


async def non_streaming_disconnect_monitor(
    req_id: str,
    http_request: Request,
    result_future: asyncio.Future,
    logger: Any,
) -> bool:
    """
    Watch for the client going away while a non-streaming result is pending.

    The connection is probed, with a 0.3 s pause between probes, until
    ``result_future`` is done. The first negative probe counts as a
    disconnect: if the future is still pending it is failed with an
    ``HTTPException`` (status 499).

    Args:
        req_id: Request identifier used as a log prefix.
        http_request: The request whose connection is probed.
        result_future: Future holding the eventual result.
        logger: Logger receiving the detection / error messages.

    Returns:
        True if a disconnect was detected. False if the result completed
        first, the monitor was cancelled, or probing raised (error logged).
    """
    try:
        while not result_future.done():
            if await check_client_connection(req_id, http_request):
                await asyncio.sleep(0.3)
                continue

            logger.info(
                f"[{req_id}] Client disconnect detected during non-streaming."
            )
            # The result may have landed while we were probing; don't clobber it.
            if not result_future.done():
                result_future.set_exception(
                    HTTPException(status_code=499, detail="Client disconnected")
                )
            return True
    except asyncio.CancelledError:
        # Cancellation just ends monitoring; it is reported as "not disconnected".
        pass
    except Exception as exc:
        logger.error(f"[{req_id}] Error in non_streaming_disconnect_monitor: {exc}")
    return False


async def setup_disconnect_monitoring(
    req_id: str, http_request: Request, result_future
) -> Tuple[Event, asyncio.Task, Callable]:
    """
    Start a background task that watches for the client going away.

    The task probes the connection, sleeping 0.3 s between probes. After 5
    consecutive negative probes it sets the returned event and, if
    ``result_future`` is still pending, fails it with an ``HTTPException``
    (status 499). If the check loop raises unexpectedly, the event is set
    and the future is failed with status 500. Cancelling the task simply
    stops monitoring.

    Args:
        req_id: Request identifier used in log and error messages.
        http_request: The incoming request to monitor.
        result_future: Future for the request's result; failed on disconnect.

    Returns:
        ``(event, task, checker)``:
        ``event`` is set once a disconnect (or checker error) is recorded;
        ``task`` is the monitoring task; ``checker(stage)`` raises
        ``ClientDisconnectedError`` when ``event`` is set, else returns False.
    """
    # Imported inside the function. NOTE(review): presumably to avoid an
    # import cycle with api_utils.server_state — confirm.
    from api_utils.server_state import state

    logger = state.logger
    disconnected = Event()
    required_misses = 5  # consecutive negative probes (~1.5 s at 0.3 s spacing)
    misses = 0

    def _fail_future(status_code: int, detail: str) -> None:
        # Resolve the pending result with an HTTP error unless it already settled.
        if not result_future.done():
            result_future.set_exception(
                HTTPException(status_code=status_code, detail=detail)
            )

    async def _poll_for_disconnect():
        nonlocal misses
        while not disconnected.is_set():
            try:
                if await check_client_connection(req_id, http_request):
                    misses = 0  # Any successful probe resets the streak.
                else:
                    misses += 1
                    if misses >= required_misses:
                        logger.info(
                            f"[{req_id}] Active detection of client disconnect (consecutive {misses} times)."
                        )
                        disconnected.set()
                        _fail_future(499, f"[{req_id}] Client closed the request")
                        break
                    logger.debug(
                        f"[{req_id}] Active detection of potential disconnect (round {misses}/{required_misses})"
                    )
                await asyncio.sleep(0.3)
            except asyncio.CancelledError:
                # Cancelled by the owner: stop quietly.
                break
            except Exception as e:
                logger.error(f"(Disco Check Task) Error: {e}")
                disconnected.set()
                _fail_future(
                    500, f"[{req_id}] Internal disconnect checker error: {e}"
                )
                break

    def check_client_disconnected(stage: str = "") -> bool:
        if not disconnected.is_set():
            return False
        logger.info(f"Client disconnected detected at stage: '{stage}'")
        raise ClientDisconnectedError(
            f"[{req_id}] Client disconnected at stage: {stage}"
        )

    monitor_task = asyncio.create_task(_poll_for_disconnect())
    return disconnected, monitor_task, check_client_disconnected