| from __future__ import annotations |
|
|
| import io |
| import socket |
| import ssl |
| import typing |
|
|
| from ..exceptions import ProxySchemeUnsupported |
|
|
| if typing.TYPE_CHECKING: |
| from typing_extensions import Self |
|
|
| from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT |
|
|
|
|
| _WriteBuffer = typing.Union[bytearray, memoryview] |
| _ReturnValue = typing.TypeVar("_ReturnValue") |
|
|
| SSL_BLOCKSIZE = 16384 |
|
|
|
|
| class SSLTransport: |
| """ |
| The SSLTransport wraps an existing socket and establishes an SSL connection. |
| |
| Contrary to Python's implementation of SSLSocket, it allows you to chain |
| multiple TLS connections together. It's particularly useful if you need to |
| implement TLS within TLS. |
| |
| The class supports most of the socket API operations. |
| """ |
|
|
| @staticmethod |
| def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: |
| """ |
| Raises a ProxySchemeUnsupported if the provided ssl_context can't be used |
| for TLS in TLS. |
| |
| The only requirement is that the ssl_context provides the 'wrap_bio' |
| methods. |
| """ |
|
|
| if not hasattr(ssl_context, "wrap_bio"): |
| raise ProxySchemeUnsupported( |
| "TLS in TLS requires SSLContext.wrap_bio() which isn't " |
| "available on non-native SSLContext" |
| ) |
|
|
| def __init__( |
| self, |
| socket: socket.socket, |
| ssl_context: ssl.SSLContext, |
| server_hostname: str | None = None, |
| suppress_ragged_eofs: bool = True, |
| ) -> None: |
| """ |
| Create an SSLTransport around socket using the provided ssl_context. |
| """ |
| self.incoming = ssl.MemoryBIO() |
| self.outgoing = ssl.MemoryBIO() |
|
|
| self.suppress_ragged_eofs = suppress_ragged_eofs |
| self.socket = socket |
|
|
| self.sslobj = ssl_context.wrap_bio( |
| self.incoming, self.outgoing, server_hostname=server_hostname |
| ) |
|
|
| |
| self._ssl_io_loop(self.sslobj.do_handshake) |
|
|
| def __enter__(self) -> Self: |
| return self |
|
|
| def __exit__(self, *_: typing.Any) -> None: |
| self.close() |
|
|
| def fileno(self) -> int: |
| return self.socket.fileno() |
|
|
| def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: |
| return self._wrap_ssl_read(len, buffer) |
|
|
| def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: |
| if flags != 0: |
| raise ValueError("non-zero flags not allowed in calls to recv") |
| return self._wrap_ssl_read(buflen) |
|
|
| def recv_into( |
| self, |
| buffer: _WriteBuffer, |
| nbytes: int | None = None, |
| flags: int = 0, |
| ) -> None | int | bytes: |
| if flags != 0: |
| raise ValueError("non-zero flags not allowed in calls to recv_into") |
| if nbytes is None: |
| nbytes = len(buffer) |
| return self.read(nbytes, buffer) |
|
|
| def sendall(self, data: bytes, flags: int = 0) -> None: |
| if flags != 0: |
| raise ValueError("non-zero flags not allowed in calls to sendall") |
| count = 0 |
| with memoryview(data) as view, view.cast("B") as byte_view: |
| amount = len(byte_view) |
| while count < amount: |
| v = self.send(byte_view[count:]) |
| count += v |
|
|
| def send(self, data: bytes, flags: int = 0) -> int: |
| if flags != 0: |
| raise ValueError("non-zero flags not allowed in calls to send") |
| return self._ssl_io_loop(self.sslobj.write, data) |
|
|
| def makefile( |
| self, |
| mode: str, |
| buffering: int | None = None, |
| *, |
| encoding: str | None = None, |
| errors: str | None = None, |
| newline: str | None = None, |
| ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: |
| """ |
| Python's httpclient uses makefile and buffered io when reading HTTP |
| messages and we need to support it. |
| |
| This is unfortunately a copy and paste of socket.py makefile with small |
| changes to point to the socket directly. |
| """ |
| if not set(mode) <= {"r", "w", "b"}: |
| raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") |
|
|
| writing = "w" in mode |
| reading = "r" in mode or not writing |
| assert reading or writing |
| binary = "b" in mode |
| rawmode = "" |
| if reading: |
| rawmode += "r" |
| if writing: |
| rawmode += "w" |
| raw = socket.SocketIO(self, rawmode) |
| self.socket._io_refs += 1 |
| if buffering is None: |
| buffering = -1 |
| if buffering < 0: |
| buffering = io.DEFAULT_BUFFER_SIZE |
| if buffering == 0: |
| if not binary: |
| raise ValueError("unbuffered streams must be binary") |
| return raw |
| buffer: typing.BinaryIO |
| if reading and writing: |
| buffer = io.BufferedRWPair(raw, raw, buffering) |
| elif reading: |
| buffer = io.BufferedReader(raw, buffering) |
| else: |
| assert writing |
| buffer = io.BufferedWriter(raw, buffering) |
| if binary: |
| return buffer |
| text = io.TextIOWrapper(buffer, encoding, errors, newline) |
| text.mode = mode |
| return text |
|
|
| def unwrap(self) -> None: |
| self._ssl_io_loop(self.sslobj.unwrap) |
|
|
| def close(self) -> None: |
| self.socket.close() |
|
|
| @typing.overload |
| def getpeercert( |
| self, binary_form: typing.Literal[False] = ... |
| ) -> _TYPE_PEER_CERT_RET_DICT | None: ... |
|
|
| @typing.overload |
| def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ... |
|
|
| def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: |
| return self.sslobj.getpeercert(binary_form) |
|
|
| def version(self) -> str | None: |
| return self.sslobj.version() |
|
|
| def cipher(self) -> tuple[str, str, int] | None: |
| return self.sslobj.cipher() |
|
|
| def selected_alpn_protocol(self) -> str | None: |
| return self.sslobj.selected_alpn_protocol() |
|
|
| def shared_ciphers(self) -> list[tuple[str, str, int]] | None: |
| return self.sslobj.shared_ciphers() |
|
|
| def compression(self) -> str | None: |
| return self.sslobj.compression() |
|
|
| def settimeout(self, value: float | None) -> None: |
| self.socket.settimeout(value) |
|
|
| def gettimeout(self) -> float | None: |
| return self.socket.gettimeout() |
|
|
| def _decref_socketios(self) -> None: |
| self.socket._decref_socketios() |
|
|
| def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: |
| try: |
| return self._ssl_io_loop(self.sslobj.read, len, buffer) |
| except ssl.SSLError as e: |
| if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: |
| return 0 |
| else: |
| raise |
|
|
| |
| @typing.overload |
| def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ... |
|
|
| |
| @typing.overload |
| def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ... |
|
|
| |
| @typing.overload |
| def _ssl_io_loop( |
| self, |
| func: typing.Callable[[int, bytearray | None], bytes], |
| arg1: int, |
| arg2: bytearray | None, |
| ) -> bytes: ... |
|
|
| def _ssl_io_loop( |
| self, |
| func: typing.Callable[..., _ReturnValue], |
| arg1: None | bytes | int = None, |
| arg2: bytearray | None = None, |
| ) -> _ReturnValue: |
| """Performs an I/O loop between incoming/outgoing and the socket.""" |
| should_loop = True |
| ret = None |
|
|
| while should_loop: |
| errno = None |
| try: |
| if arg1 is None and arg2 is None: |
| ret = func() |
| elif arg2 is None: |
| ret = func(arg1) |
| else: |
| ret = func(arg1, arg2) |
| except ssl.SSLError as e: |
| if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): |
| |
| raise e |
| errno = e.errno |
|
|
| buf = self.outgoing.read() |
| self.socket.sendall(buf) |
|
|
| if errno is None: |
| should_loop = False |
| elif errno == ssl.SSL_ERROR_WANT_READ: |
| buf = self.socket.recv(SSL_BLOCKSIZE) |
| if buf: |
| self.incoming.write(buf) |
| else: |
| self.incoming.write_eof() |
| return typing.cast(_ReturnValue, ret) |
|
|