"""WebSocket client for asyncio."""

import asyncio
import sys
from types import TracebackType
from typing import Any, Optional, Type, cast

import attr

from ._websocket.reader import WebSocketDataQueue
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .http import (
    WS_CLOSED_MESSAGE,
    WS_CLOSING_MESSAGE,
    WebSocketError,
    WSCloseCode,
    WSMessage,
    WSMsgType,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
from .streams import EofStream
from .typedefs import (
    DEFAULT_JSON_DECODER,
    DEFAULT_JSON_ENCODER,
    JSONDecoder,
    JSONEncoder,
)

# The rest of the module calls ``async_timeout.timeout(...)``. On Python 3.11+
# that name is bound to ``asyncio`` itself; on older versions it is the
# separate ``async_timeout`` package.
if sys.version_info >= (3, 11):
    import asyncio as async_timeout
else:
    import async_timeout


@attr.s(frozen=True, slots=True)
class ClientWSTimeout:
    """Immutable pair of timeouts used by ``ClientWebSocketResponse``."""

    # Default timeout applied by ``receive()`` when the caller passes none;
    # a falsy value means ``receive()`` reads without a timeout.
    ws_receive = attr.ib(type=Optional[float], default=None)
    # Timeout applied by ``close()`` to each read while it waits for the
    # peer's CLOSE message.
    ws_close = attr.ib(type=Optional[float], default=None)


# Default client timeouts: no receive timeout, 10.0 for the close wait.
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0)


class ClientWebSocketResponse:
    """Client side of an established WebSocket connection.

    Outgoing frames are sent through the supplied writer and incoming
    messages are read from the supplied reader queue by :meth:`receive`.
    The object tracks the closing/closed state, the close code, the last
    exception, and (when ``heartbeat`` is set) a periodic PING/PONG check.

    It can be used as an async iterator over incoming messages and as an
    async context manager that closes the connection on exit.
    """

    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: Optional[str],
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: Optional[float] = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        """Store the connection pieces and arm the heartbeat if configured.

        Args:
            reader: Queue that incoming messages are read from.
            writer: Writer used to send frames and the close frame.
            protocol: Value exposed through the ``protocol`` property.
            response: Underlying response; its ``connection`` is kept and it
                is closed when the WebSocket is torn down.
            timeout: ``ws_receive`` is the default timeout of ``receive()``;
                ``ws_close`` bounds each read while waiting for the peer's
                CLOSE in ``close()``.
            autoclose: When true, ``receive()`` calls ``close()`` after a
                CLOSE message arrives.
            autoping: When true, ``receive()`` answers PING with a PONG and
                skips both PING and PONG messages instead of returning them.
            loop: Event loop used for timers, futures and tasks.
            heartbeat: Interval, in seconds, after which a PING is sent if no
                message was received; ``None`` disables the heartbeat.
            compress: Value exposed through the ``compress`` property.
            client_notakeover: Value exposed through the
                ``client_notakeover`` property.
        """
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: Optional[int] = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # The peer gets half of the heartbeat interval to answer a PING.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        self._loop = loop
        # True while a ``receive()`` call is blocked reading from the reader.
        self._waiting: bool = False
        # Future that ``close()`` awaits until a pending ``receive()`` returns.
        self._close_wait: Optional[asyncio.Future[None]] = None
        self._exception: Optional[BaseException] = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        self._ping_task: Optional[asyncio.Task[None]] = None

        self._reset_heartbeat()

    def _cancel_heartbeat(self) -> None:
        """Cancel the PONG timer, the heartbeat timer and any ping task."""
        self._cancel_pong_response_cb()
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None

    def _cancel_pong_response_cb(self) -> None:
        """Cancel the timer that fires when no PONG has been received."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

    def _reset_heartbeat(self) -> None:
        """Push the next heartbeat deadline forward; no-op without heartbeat."""
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        conn = self._conn
        # Use the connector's ceil threshold when a connection is available,
        # otherwise fall back to 5.
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # An already scheduled timer is deliberately left in place rather
            # than cancelled and re-created: ``_send_heartbeat`` compares the
            # current time with ``_heartbeat_when`` and re-arms itself when it
            # fires before the updated deadline.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a PING and start waiting for the PONG."""
        self._heartbeat_cb = None
        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Fired before the current deadline (it was pushed forward after
            # this timer was scheduled); reschedule for the real deadline.
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return

        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)

        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # With ``eager_start=True`` the coroutine starts running while the
            # Task is constructed, so the ping may already be finished here
            # without a trip through the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)

        if not ping_task.done():
            # Keep a reference so the task can be cancelled on close.
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            self._ping_task_done(ping_task)

    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes."""
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None

    def _pong_not_received(self) -> None:
        """Timer callback: the PONG deadline passed without a message."""
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )

    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing."""
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            # Wake up the pending ``receive()`` with an ERROR message.
            self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)

    def _set_closed(self) -> None:
        """Set the connection to closed.

        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()

    def _set_closing(self) -> None:
        """Set the connection to closing.

        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()

    @property
    def closed(self) -> bool:
        """Whether the connection has been marked closed."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded so far, or ``None`` if none was recorded."""
        return self._close_code

    @property
    def protocol(self) -> Optional[str]:
        """The ``protocol`` value passed to the constructor."""
        return self._protocol

    @property
    def compress(self) -> int:
        """The ``compress`` value passed to the constructor."""
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        """The ``client_notakeover`` value passed to the constructor."""
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """extra info from connection transport"""
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the last exception recorded on this connection, if any."""
        return self._exception

    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame carrying ``message``."""
        await self._writer.send_frame(message, WSMsgType.PING)

    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame carrying ``message``."""
        await self._writer.send_frame(message, WSMsgType.PONG)

    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)

    async def send_str(self, data: str, compress: Optional[int] = None) -> None:
        """Send ``data`` as a UTF-8 encoded TEXT frame.

        Raises:
            TypeError: If ``data`` is not a ``str``.
        """
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )

    async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
        """Send ``data`` as a BINARY frame.

        Raises:
            TypeError: If ``data`` is not ``bytes``, ``bytearray`` or
                ``memoryview``.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

    async def send_json(
        self,
        data: Any,
        compress: Optional[int] = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)

    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Close the connection and wait for the peer's CLOSE message.

        Args:
            code: Close code passed to the writer's ``close()``.
            message: Close payload passed to the writer's ``close()``.

        Returns:
            ``False`` if the connection was already closed, ``True`` otherwise
            (including when sending or waiting failed with a non-cancellation
            exception, which is stored and available via ``exception()``).

        Raises:
            asyncio.CancelledError: Re-raised after the response is closed
                and the close code is set to ``ABNORMAL_CLOSURE``.
        """
        # A ``receive()`` call may be blocked on the reader (``close()`` can
        # be invoked while it is waiting). Feed it a CLOSING message so it
        # returns, and wait until it signals that through ``_close_wait``.
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._close_wait

        if self._closed:
            return False

        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        if self._close_code:
            # A close code is already recorded, so there is nothing to wait for.
            self._response.close()
            return True

        # Read until the peer's CLOSE arrives; other messages are dropped.
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True

            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Wait for and return the next message.

        Args:
            timeout: Timeout for a single read. A falsy value (``None`` or
                ``0``) falls back to the configured ``ws_receive`` timeout.

        Returns:
            The next message. A CLOSED message is returned when the
            connection is closed or closing, on end of stream, or on a
            ``ClientError``; an ERROR message carrying the exception is
            returned for ``WebSocketError`` and other unexpected exceptions.

        Raises:
            RuntimeError: If another ``receive()`` call is already waiting.
            asyncio.CancelledError, asyncio.TimeoutError: Re-raised after the
                close code is set to ``ABNORMAL_CLOSURE``.
        """
        receive_timeout = timeout or self._timeout.ws_receive

        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # The timeout context is only entered when a receive
                        # timeout is actually configured.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                    # Any received message postpones the next heartbeat.
                    self._reset_heartbeat()
                finally:
                    self._waiting = False
                    if self._close_wait:
                        # Let a ``close()`` waiting on this read continue.
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except ClientError:
                # Mark the connection closed without running the close
                # handshake and report it as CLOSED.
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # Message types that need no internal handling are returned
                # to the caller right away.
                return msg

            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive the next message and return its data as ``str``.

        Raises:
            WSMessageTypeError: If the received message is not a TEXT message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return cast(str, msg.data)

    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive the next message and return its data as ``bytes``.

        Raises:
            WSMessageTypeError: If the received message is not a BINARY
                message.
        """
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return cast(bytes, msg.data)

    async def receive_json(
        self,
        *,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        timeout: Optional[float] = None,
    ) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    def __aiter__(self) -> "ClientWebSocketResponse":
        """Return ``self``; messages are produced by ``__anext__``."""
        return self

    async def __anext__(self) -> WSMessage:
        """Return the next message, stopping on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg

    async def __aenter__(self) -> "ClientWebSocketResponse":
        """Enter the async context; returns ``self``."""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the connection when leaving the async context."""
        await self.close()