| | """Helpers for WebSocket protocol versions 13 and 8.""" |
| |
|
| | import functools |
| | import re |
| | from struct import Struct |
| | from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple |
| |
|
| | from ..helpers import NO_EXTENSIONS |
| | from .models import WSHandshakeError |
| |
|
# Pre-compiled struct codecs.  "!" selects network (big-endian) byte order
# with standard sizes: B = unsigned 8-bit, H = unsigned 16-bit,
# L = unsigned 32-bit, Q = unsigned 64-bit.
UNPACK_LEN3 = Struct("!Q").unpack_from
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
UNPACK_CLOSE_CODE = Struct("!H").unpack
PACK_CLOSE_CODE = Struct("!H").pack
PACK_RANDBITS = Struct("!L").pack

# 16 KiB.
MSG_SIZE: Final[int] = 1 << 14
# Masking-key length; matches the 4-byte mask `_websocket_mask_python` requires.
MASK_LEN: Final[int] = 4

# The fixed handshake GUID from RFC 6455 (section 1.3).
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
| |
|
| |
|
| | |
@functools.lru_cache
def _xor_table() -> List[bytes]:
    """Build the byte-XOR translation tables used for masking.

    Element `k` of the returned list is a 256-byte table that maps each byte
    value `v` to `v ^ k`, ready for `bytearray.translate`.  The decorator
    caches the result, so the list is built on the first call only and the
    same object is returned afterwards.
    """
    byte_values = range(256)
    return [bytes(map(key.__xor__, byte_values)) for key in byte_values]
| |
|
| |
|
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """Apply the RFC 6455 (section 5.3) XOR mask to `data` in place.

    `mask` must be a 4-byte `bytes` object and `data` a `bytearray` of any
    length.  Byte `i` of `data` is XOR-ed with `mask[i % 4]`.  Nothing is
    returned; `data` itself is modified.

    This is the pure-Python fallback.  A compiled implementation is selected
    instead when one can be imported.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    tables = _xor_table()
    # Unpacking into four names also rejects a mask of any other length when
    # asserts are stripped (python -O).
    k0, k1, k2, k3 = (tables[byte] for byte in mask)
    # Positions start, start+4, start+8, ... all use the same mask byte, so
    # each stride-4 slice is translated in a single pass.
    for start, table in enumerate((k0, k1, k2, k3)):
        data[start::4] = data[start::4].translate(table)
| |
|
| |
|
if not TYPE_CHECKING and not NO_EXTENSIONS:
    # Prefer the compiled masking routine; fall back to the pure-Python
    # version when the extension module cannot be imported.
    try:
        from .mask import _websocket_mask_cython
    except ImportError:
        websocket_mask = _websocket_mask_python
    else:
        websocket_mask = _websocket_mask_cython
else:
    # Type checkers, and runs with extensions disabled, always use the
    # pure-Python implementation.
    websocket_mask = _websocket_mask_python
| |
|
| |
|
# Parameter list that may follow `permessage-deflate`: zero or more
# "; <param>" items, anchored at both ends.  Capture groups, which
# ws_ext_parse reads by number:
#   1 server_no_context_takeover
#   2 client_no_context_takeover
#   3 server_max_window_bits[=N]    4 its N
#   5 client_max_window_bits[=N]    6 its N
# NOTE(review): whitespace is accepted after each ";" but not before it.
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^(?:;\s*(?:"
    + "|".join(
        (
            r"(server_no_context_takeover)",
            r"(client_no_context_takeover)",
            r"(server_max_window_bits(?:=(\d+))?)",
            r"(client_max_window_bits(?:=(\d+))?)",
        )
    )
    + r"))*$"
)

# Finds each `permessage-deflate` token.  Group 1 is whatever follows it up
# to the next comma, or None when nothing follows.
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
| |
|
| |
|
def _parse_window_bits(value: str) -> int:
    """Convert the digits of a `*_max_window_bits` parameter to an int.

    The extension regex accepts a digit run of any length, so a peer can send
    a number long enough to exceed the interpreter's integer string
    conversion limit (`sys.get_int_max_str_digits()` on versions that have
    it).  In that case `int()` raises `ValueError`.  Such a value is reported
    as 0, so the caller's 9..15 range check rejects it like any other
    unusable window size instead of leaking a `ValueError`.
    """
    try:
        return int(value)
    except ValueError:
        return 0


def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Negotiate permessage-deflate settings from an extensions string.

    Each `permessage-deflate` occurrence in `extstr` is examined in order,
    together with the parameter string that follows it (up to the next `,`).
    The first acceptable occurrence decides the result.

    `isserver=True` applies the server-side rules.  Only
    `server_max_window_bits` and `server_no_context_takeover` are used, and
    an occurrence with unrecognized parameters or a window size outside
    9..15 is skipped.  `isserver=False` applies the client-side rules.  Only
    `client_max_window_bits` and `client_no_context_takeover` are used, and
    unrecognized parameters or a bad window size raise `WSHandshakeError`.

    Returns `(compress, notakeover)`.  `compress` is 0 when `extstr` is empty
    or no acceptable occurrence was found.  Otherwise it is a window size in
    9..15, defaulting to 15 when none was given.  `notakeover` is True when
    the side's `*_no_context_takeover` parameter was present.
    """
    if not extstr:
        return 0, False

    compress = 0
    notakeover = False
    for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
        defext = ext.group(1)
        # A bare `permessage-deflate` (no parameters) means the default
        # 15-bit window.
        if not defext:
            compress = 15
            break
        match = _WS_EXT_RE.match(defext)
        if match:
            compress = 15
            if isserver:
                # The server never fails the handshake here; an offer it
                # cannot accept is skipped in favour of the next one.
                if match.group(4):
                    compress = _parse_window_bits(match.group(4))
                    # zlib does not support wbits=8, so only 9..15 is
                    # usable; otherwise try the next offer.
                    if compress > 15 or compress < 9:
                        compress = 0
                        continue
                if match.group(1):
                    notakeover = True
                # client_max_window_bits (groups 5 & 6) is ignored here.
                break
            else:
                if match.group(6):
                    compress = _parse_window_bits(match.group(6))
                    # The client fails the handshake on an unusable size.
                    if compress > 15 or compress < 9:
                        raise WSHandshakeError("Invalid window size")
                if match.group(2):
                    notakeover = True
                # server_max_window_bits (groups 3 & 4) is ignored here.
                break
        # Unrecognized parameters: the client fails, the server moves on.
        elif not isserver:
            raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))

    return compress, notakeover
| |
|
| |
|
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build a `permessage-deflate` extension string.

    `compress` is the window size to advertise and must be in 9..15.  A value
    below 15 is sent as `server_max_window_bits=<compress>`.  When `isserver`
    is false (client side), `client_max_window_bits` is included too.
    `server_notakeover` appends `server_no_context_takeover`.  The pieces
    are joined with "; ".

    Raises ValueError if `compress` is outside 9..15; zlib does not support
    wbits=8.
    """
    if compress > 15 or compress < 9:
        raise ValueError(
            "Compress wbits must between 9 and 15, zlib does not support wbits=8"
        )
    # (parameter, include?) pairs, in the order they are emitted.
    optional_params = (
        ("client_max_window_bits", not isserver),
        (f"server_max_window_bits={compress}", compress < 15),
        ("server_no_context_takeover", server_notakeover),
    )
    chosen = [name for name, wanted in optional_params if wanted]
    return "; ".join(["permessage-deflate"] + chosen)
| |
|