| | from __future__ import annotations |
| |
|
# Public API of this module; each name listed here is a class defined in this file.
__all__ = (
    "TLSAttribute",
    "TLSConnectable",
    "TLSListener",
    "TLSStream",
)
| |
|
| | import logging |
| | import re |
| | import ssl |
| | import sys |
| | from collections.abc import Callable, Mapping |
| | from dataclasses import dataclass |
| | from functools import wraps |
| | from ssl import SSLContext |
| | from typing import Any, TypeVar |
| |
|
| | from .. import ( |
| | BrokenResourceError, |
| | EndOfStream, |
| | aclose_forcefully, |
| | get_cancelled_exc_class, |
| | to_thread, |
| | ) |
| | from .._core._typedattr import TypedAttributeSet, typed_attribute |
| | from ..abc import ( |
| | AnyByteStream, |
| | AnyByteStreamConnectable, |
| | ByteStream, |
| | ByteStreamConnectable, |
| | Listener, |
| | TaskGroup, |
| | ) |
| |
|
| | if sys.version_info >= (3, 10): |
| | from typing import TypeAlias |
| | else: |
| | from typing_extensions import TypeAlias |
| |
|
| | if sys.version_info >= (3, 11): |
| | from typing import TypeVarTuple, Unpack |
| | else: |
| | from typing_extensions import TypeVarTuple, Unpack |
| |
|
| | if sys.version_info >= (3, 12): |
| | from typing import override |
| | else: |
| | from typing_extensions import override |
| |
|
# Return type of the callable handed to TLSStream._call_sslobject_method()
T_Retval = TypeVar("T_Retval")
# Types of the positional arguments forwarded to that callable
PosArgsT = TypeVarTuple("PosArgsT")
# Nested tuple shapes used as value types in the annotation of
# TLSAttribute.peer_certificate: a tuple of (str, str) pairs, and a tuple of such
# tuples. NOTE(review): presumably these mirror the structures returned by
# ssl.SSLObject.getpeercert() -- confirm against the ssl documentation.
_PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
| |
|
| |
|
class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol (``SSLObject.selected_alpn_protocol()``)
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding data (``SSLObject.get_channel_binding()`` called without
    #: arguments, i.e. the ``tls-unique`` binding type per the :mod:`ssl` docs)
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher (``SSLObject.cipher()``)
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (``SSLObject.getpeercert(False)``; see
    #: :meth:`ssl.SSLSocket.getpeercert` for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form (``SSLObject.getpeercert(True)``)
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: the result of ``SSLObject.shared_ciphers()`` on the server side (``None`` if
    #: this is the client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version string (``SSLObject.version()``, e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
| |
|
| |
|
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    # The TLS engine; it never touches the network itself, only the two memory BIOs
    _ssl_object: ssl.SSLObject
    # Encrypted bytes received from the transport are written here for the engine
    _read_bio: ssl.MemoryBIO
    # Encrypted bytes produced by the engine are read from here and sent out
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Used to create a
            default context when an explicit context has not been provided, and
            passed on to ``ssl_context.wrap_bio()``.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        # No explicit role given: a host name implies the client side
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Clear OP_IGNORE_UNEXPECTED_EOF (on Python versions where the ssl module
            # defines it) so that unexpected EOFs are still detected.
            # NOTE(review): indentation was not recoverable from the reviewed copy;
            # this is assumed to apply only to the default context created above,
            # not to a caller-supplied context -- confirm.
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # Only the exact standard library class has wrap_bio() called inline; any
        # other context type (subclasses included) has it run in a worker thread --
        # presumably because an external implementation may do blocking I/O there.
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            # Arguments are passed positionally here; in the standard library
            # signature the last three are server_side, server_hostname and session
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        # The stream is only returned once the handshake has completed
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call ``func(*args)`` repeatedly until it completes, moving encrypted data
        between the memory BIOs and the transport stream as the call demands.

        :param func: the callable to invoke (the callers in this class pass bound
            methods of the SSL object)
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func``
        :raises BrokenResourceError: if the transport raises :exc:`OSError` while
            data is being exchanged for a read, if ``func`` raises
            :exc:`ssl.SSLSyscallError`, or on an unexpected EOF when
            ``standard_compatible`` is ``True``
        :raises EndOfStream: on an unexpected EOF when ``standard_compatible`` is
            ``False``
        :raises ~ssl.SSLError: any other SSL error raised by ``func``

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                # The engine needs more incoming data before it can make progress
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Tell the engine that no more data will arrive, then retry the
                    # call so it can react to the EOF
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                # Send out what the engine has produced so far, then retry the call
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # An unexpected EOF is recognized either by exception type or by the
                # error text
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        # In non-standard-compatible mode the missing closing
                        # handshake is reported as a regular end of stream
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        # Mark both BIOs as finished before draining what is left in the read BIO
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                # Whatever interrupted the closing handshake (BaseException also
                # covers cancellation), still close the transport via
                # aclose_forcefully() before propagating the exception
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        # An empty read result is reported as the end of the stream
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        # Every path through this method raises NotImplementedError; the version
        # check only selects a more specific message for sessions older than TLSv1.3
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            # A missing minor component (e.g. "TLSv1") is treated as minor version 0
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The transport stream's attributes come first, so the TLS attributes below
        # take precedence on any key collision. Every value is a zero-argument
        # callable that is only evaluated when the attribute is looked up.
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # shared_ciphers() is only queried on the server side
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
            if self._ssl_object.server_side
            else None,
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
| |
|
| |
|
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        is_cancellation = isinstance(exc, get_cancelled_exc_class())

        # Log all except cancellation exceptions. The exception is passed explicitly
        # instead of relying on the interpreter's "currently handled exception"
        # state.
        if not is_cancellation:
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. The exception
        # object is raised explicitly: a bare ``raise`` would depend on this method
        # being awaited from inside an ``except`` block and would otherwise fail
        # with "RuntimeError: No active exception to reraise".
        if is_cancellation or not isinstance(exc, Exception):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, performing a TLS handshake on
        each accepted stream before passing it to ``handler``.

        :param handler: a callable awaited with the :class:`TLSStream` of every
            connection whose handshake succeeded
        :param task_group: passed through unchanged to the wrapped listener's
            ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            # NOTE(review): imported locally rather than at module level --
            # presumably to avoid a circular import; confirm
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                # Covers handshake errors, the timeout and cancellation alike; the
                # hook decides what gets logged and what gets reraised
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        # Closing this listener just closes the wrapped listener
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }
| |
|
| |
|
class TLSConnectable(ByteStreamConnectable):
    """
    A connectable that layers TLS on top of another connectable: once the wrapped
    connectable has produced a stream, a TLS handshake is performed over it.

    :param connectable: the connectable to wrap
    :param hostname: host name of the server (if host name checking is desired)
    :param ssl_context: the SSLContext object to use (if not provided, a secure default
        will be created)
    :param standard_compatible: if ``False``, skip the closing handshake when closing
        the connection, and don't raise an exception if the server does the same
    """

    def __init__(
        self,
        connectable: AnyByteStreamConnectable,
        *,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> None:
        self.connectable = connectable

        # Fall back to a default client-side context when none was supplied
        context = ssl_context
        if not context:
            context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)

        self.ssl_context: SSLContext = context
        if not isinstance(context, ssl.SSLContext):
            raise TypeError(
                "ssl_context must be an instance of ssl.SSLContext, not "
                + type(context).__name__
            )

        self.hostname = hostname
        self.standard_compatible = standard_compatible

    @override
    async def connect(self) -> TLSStream:
        """
        Connect through the wrapped connectable and negotiate TLS on the result.

        :return: the TLS-wrapped stream

        """
        transport = await self.connectable.connect()
        try:
            tls_stream = await TLSStream.wrap(
                transport,
                hostname=self.hostname,
                ssl_context=self.ssl_context,
                standard_compatible=self.standard_compatible,
            )
        except BaseException:
            # Do not leave the raw connection open if the handshake fails or is
            # interrupted
            await aclose_forcefully(transport)
            raise

        return tls_stream
| |
|