| | from __future__ import annotations |
| |
|
| | import io |
| | import socket |
| | import ssl |
| | import typing |
| |
|
| | from ..exceptions import ProxySchemeUnsupported |
| |
|
| | if typing.TYPE_CHECKING: |
| | from typing import Literal |
| |
|
| | from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT |
| |
|
| |
|
# Generic result type of whatever callable ``SSLTransport._ssl_io_loop`` drives.
_ReturnValue = typing.TypeVar("_ReturnValue")
# Type variable so ``__enter__`` is typed as returning the same (sub)class.
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
# Writable buffer types accepted by ``SSLTransport.recv_into``.
_WriteBuffer = typing.Union[bytearray, memoryview]

# Number of bytes requested from the raw socket on each recv() while the TLS
# layer is asking for more input (16 KiB).
# NOTE(review): presumably sized to one maximum TLS record; confirm.
SSL_BLOCKSIZE = 16 * 1024
| |
|
| |
|
| | class SSLTransport: |
| | """ |
| | The SSLTransport wraps an existing socket and establishes an SSL connection. |
| | |
| | Contrary to Python's implementation of SSLSocket, it allows you to chain |
| | multiple TLS connections together. It's particularly useful if you need to |
| | implement TLS within TLS. |
| | |
| | The class supports most of the socket API operations. |
| | """ |
| |
|
| | @staticmethod |
| | def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: |
| | """ |
| | Raises a ProxySchemeUnsupported if the provided ssl_context can't be used |
| | for TLS in TLS. |
| | |
| | The only requirement is that the ssl_context provides the 'wrap_bio' |
| | methods. |
| | """ |
| |
|
| | if not hasattr(ssl_context, "wrap_bio"): |
| | raise ProxySchemeUnsupported( |
| | "TLS in TLS requires SSLContext.wrap_bio() which isn't " |
| | "available on non-native SSLContext" |
| | ) |
| |
|
| | def __init__( |
| | self, |
| | socket: socket.socket, |
| | ssl_context: ssl.SSLContext, |
| | server_hostname: str | None = None, |
| | suppress_ragged_eofs: bool = True, |
| | ) -> None: |
| | """ |
| | Create an SSLTransport around socket using the provided ssl_context. |
| | """ |
| | self.incoming = ssl.MemoryBIO() |
| | self.outgoing = ssl.MemoryBIO() |
| |
|
| | self.suppress_ragged_eofs = suppress_ragged_eofs |
| | self.socket = socket |
| |
|
| | self.sslobj = ssl_context.wrap_bio( |
| | self.incoming, self.outgoing, server_hostname=server_hostname |
| | ) |
| |
|
| | |
| | self._ssl_io_loop(self.sslobj.do_handshake) |
| |
|
| | def __enter__(self: _SelfT) -> _SelfT: |
| | return self |
| |
|
| | def __exit__(self, *_: typing.Any) -> None: |
| | self.close() |
| |
|
| | def fileno(self) -> int: |
| | return self.socket.fileno() |
| |
|
| | def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: |
| | return self._wrap_ssl_read(len, buffer) |
| |
|
| | def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: |
| | if flags != 0: |
| | raise ValueError("non-zero flags not allowed in calls to recv") |
| | return self._wrap_ssl_read(buflen) |
| |
|
| | def recv_into( |
| | self, |
| | buffer: _WriteBuffer, |
| | nbytes: int | None = None, |
| | flags: int = 0, |
| | ) -> None | int | bytes: |
| | if flags != 0: |
| | raise ValueError("non-zero flags not allowed in calls to recv_into") |
| | if nbytes is None: |
| | nbytes = len(buffer) |
| | return self.read(nbytes, buffer) |
| |
|
| | def sendall(self, data: bytes, flags: int = 0) -> None: |
| | if flags != 0: |
| | raise ValueError("non-zero flags not allowed in calls to sendall") |
| | count = 0 |
| | with memoryview(data) as view, view.cast("B") as byte_view: |
| | amount = len(byte_view) |
| | while count < amount: |
| | v = self.send(byte_view[count:]) |
| | count += v |
| |
|
| | def send(self, data: bytes, flags: int = 0) -> int: |
| | if flags != 0: |
| | raise ValueError("non-zero flags not allowed in calls to send") |
| | return self._ssl_io_loop(self.sslobj.write, data) |
| |
|
| | def makefile( |
| | self, |
| | mode: str, |
| | buffering: int | None = None, |
| | *, |
| | encoding: str | None = None, |
| | errors: str | None = None, |
| | newline: str | None = None, |
| | ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: |
| | """ |
| | Python's httpclient uses makefile and buffered io when reading HTTP |
| | messages and we need to support it. |
| | |
| | This is unfortunately a copy and paste of socket.py makefile with small |
| | changes to point to the socket directly. |
| | """ |
| | if not set(mode) <= {"r", "w", "b"}: |
| | raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") |
| |
|
| | writing = "w" in mode |
| | reading = "r" in mode or not writing |
| | assert reading or writing |
| | binary = "b" in mode |
| | rawmode = "" |
| | if reading: |
| | rawmode += "r" |
| | if writing: |
| | rawmode += "w" |
| | raw = socket.SocketIO(self, rawmode) |
| | self.socket._io_refs += 1 |
| | if buffering is None: |
| | buffering = -1 |
| | if buffering < 0: |
| | buffering = io.DEFAULT_BUFFER_SIZE |
| | if buffering == 0: |
| | if not binary: |
| | raise ValueError("unbuffered streams must be binary") |
| | return raw |
| | buffer: typing.BinaryIO |
| | if reading and writing: |
| | buffer = io.BufferedRWPair(raw, raw, buffering) |
| | elif reading: |
| | buffer = io.BufferedReader(raw, buffering) |
| | else: |
| | assert writing |
| | buffer = io.BufferedWriter(raw, buffering) |
| | if binary: |
| | return buffer |
| | text = io.TextIOWrapper(buffer, encoding, errors, newline) |
| | text.mode = mode |
| | return text |
| |
|
| | def unwrap(self) -> None: |
| | self._ssl_io_loop(self.sslobj.unwrap) |
| |
|
| | def close(self) -> None: |
| | self.socket.close() |
| |
|
| | @typing.overload |
| | def getpeercert( |
| | self, binary_form: Literal[False] = ... |
| | ) -> _TYPE_PEER_CERT_RET_DICT | None: |
| | ... |
| |
|
| | @typing.overload |
| | def getpeercert(self, binary_form: Literal[True]) -> bytes | None: |
| | ... |
| |
|
| | def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: |
| | return self.sslobj.getpeercert(binary_form) |
| |
|
| | def version(self) -> str | None: |
| | return self.sslobj.version() |
| |
|
| | def cipher(self) -> tuple[str, str, int] | None: |
| | return self.sslobj.cipher() |
| |
|
| | def selected_alpn_protocol(self) -> str | None: |
| | return self.sslobj.selected_alpn_protocol() |
| |
|
| | def selected_npn_protocol(self) -> str | None: |
| | return self.sslobj.selected_npn_protocol() |
| |
|
| | def shared_ciphers(self) -> list[tuple[str, str, int]] | None: |
| | return self.sslobj.shared_ciphers() |
| |
|
| | def compression(self) -> str | None: |
| | return self.sslobj.compression() |
| |
|
| | def settimeout(self, value: float | None) -> None: |
| | self.socket.settimeout(value) |
| |
|
| | def gettimeout(self) -> float | None: |
| | return self.socket.gettimeout() |
| |
|
| | def _decref_socketios(self) -> None: |
| | self.socket._decref_socketios() |
| |
|
| | def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: |
| | try: |
| | return self._ssl_io_loop(self.sslobj.read, len, buffer) |
| | except ssl.SSLError as e: |
| | if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: |
| | return 0 |
| | else: |
| | raise |
| |
|
| | |
| | @typing.overload |
| | def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: |
| | ... |
| |
|
| | |
| | @typing.overload |
| | def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: |
| | ... |
| |
|
| | |
| | @typing.overload |
| | def _ssl_io_loop( |
| | self, |
| | func: typing.Callable[[int, bytearray | None], bytes], |
| | arg1: int, |
| | arg2: bytearray | None, |
| | ) -> bytes: |
| | ... |
| |
|
| | def _ssl_io_loop( |
| | self, |
| | func: typing.Callable[..., _ReturnValue], |
| | arg1: None | bytes | int = None, |
| | arg2: bytearray | None = None, |
| | ) -> _ReturnValue: |
| | """Performs an I/O loop between incoming/outgoing and the socket.""" |
| | should_loop = True |
| | ret = None |
| |
|
| | while should_loop: |
| | errno = None |
| | try: |
| | if arg1 is None and arg2 is None: |
| | ret = func() |
| | elif arg2 is None: |
| | ret = func(arg1) |
| | else: |
| | ret = func(arg1, arg2) |
| | except ssl.SSLError as e: |
| | if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): |
| | |
| | raise e |
| | errno = e.errno |
| |
|
| | buf = self.outgoing.read() |
| | self.socket.sendall(buf) |
| |
|
| | if errno is None: |
| | should_loop = False |
| | elif errno == ssl.SSL_ERROR_WANT_READ: |
| | buf = self.socket.recv(SSL_BLOCKSIZE) |
| | if buf: |
| | self.incoming.write(buf) |
| | else: |
| | self.incoming.write_eof() |
| | return typing.cast(_ReturnValue, ret) |
| |
|