Spaces:
Runtime error
Runtime error
| import collections | |
| import enum | |
| import warnings | |
| try: | |
| import ssl | |
| except ImportError: # pragma: no cover | |
| ssl = None | |
| from . import constants | |
| from . import exceptions | |
| from . import protocols | |
| from . import transports | |
| from .log import logger | |
if ssl is not None:
    # Exception types that the handshake, read, write and shutdown paths in
    # this module catch and treat as "no progress possible right now"
    # rather than as fatal errors.  Only defined when the ssl module could
    # be imported.
    SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
@enum.unique
class SSLProtocolState(enum.Enum):
    """States of the SSL layer state machine.

    The transitions accepted by ``SSLProtocol._set_state`` are::

        UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> FLUSHING -> SHUTDOWN

    plus a transition from any state back to ``UNWRAPPED``.
    """

    # No TLS session is active (initial state, and state after teardown).
    UNWRAPPED = "UNWRAPPED"
    # The TLS handshake is in progress.
    DO_HANDSHAKE = "DO_HANDSHAKE"
    # Handshake finished; application data may be read and written.
    WRAPPED = "WRAPPED"
    # Shutdown requested; remaining incoming data is being drained.
    FLUSHING = "FLUSHING"
    # The TLS session is being unwrapped.
    SHUTDOWN = "SHUTDOWN"
@enum.unique
class AppProtocolState(enum.Enum):
    """Which callbacks the application protocol has received so far.

    This tracks the state of app protocol (https://git.io/fj59P):

        INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST

    * cm: connection_made()
    * dr: data_received()
    * er: eof_received()
    * cl: connection_lost()
    """

    STATE_INIT = "STATE_INIT"
    STATE_CON_MADE = "STATE_CON_MADE"
    STATE_EOF = "STATE_EOF"
    STATE_CON_LOST = "STATE_CON_LOST"
| def _create_transport_context(server_side, server_hostname): | |
| if server_side: | |
| raise ValueError('Server side SSL needs a valid SSLContext') | |
| # Client side may pass ssl=True to use a default | |
| # context; in that case the sslcontext passed is None. | |
| # The default is secure for client connections. | |
| # Python 3.4+: use up-to-date strong settings. | |
| sslcontext = ssl.create_default_context() | |
| if not server_hostname: | |
| sslcontext.check_hostname = False | |
| return sslcontext | |
def add_flowcontrol_defaults(high, low, kb):
    """Resolve high/low water marks, filling in whichever is missing.

    * Neither given: ``high = kb * 1024`` and ``low = high // 4``.
    * Only ``low`` given: ``high = 4 * low``.
    * Only ``high`` given: ``low = high // 4``.

    Returns:
        The ``(high, low)`` pair.

    Raises:
        ValueError: unless ``high >= low >= 0`` holds for the result.
    """
    if high is not None:
        hi = high
    elif low is None:
        hi = kb * 1024
    else:
        hi = 4 * low

    lo = hi // 4 if low is None else low

    if not hi >= lo >= 0:
        raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
                         (hi, lo))

    return hi, lo
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport object handed to the application protocol.

    Nearly every method is a thin delegate to the owning SSLProtocol
    instance (``self._ssl_protocol``); this class itself only tracks
    whether it has been closed.
    """

    _start_tls_compatible = True
    _sendfile_compatible = constants._SendfileMode.FALLBACK

    def __init__(self, loop, ssl_protocol):
        self._loop = loop
        self._ssl_protocol = ssl_protocol
        self._closed = False

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def set_protocol(self, protocol):
        """Replace the application protocol served by the SSL protocol."""
        self._ssl_protocol._set_app_protocol(protocol)

    def get_protocol(self):
        """Return the current application protocol."""
        return self._ssl_protocol._app_protocol

    def is_closing(self):
        """Return True once close(), abort() or a forced close happened."""
        return self._closed

    def close(self):
        """Close the transport.

        Buffered data will be flushed asynchronously.  No more data
        will be received.  After all buffered data is flushed, the
        protocol's connection_lost() method will (eventually) be called
        with None as its argument.
        """
        if not self._closed:
            self._closed = True
            self._ssl_protocol._start_shutdown()
        else:
            # Already closed: drop the reference to the SSL protocol.
            # From here on self._ssl_protocol is None.
            self._ssl_protocol = None

    def __del__(self, _warnings=warnings):
        # The warnings module is bound as a default argument so it is
        # still reachable while this finalizer runs.
        if not self._closed:
            self._closed = True
            _warnings.warn(
                "unclosed transport <asyncio._SSLProtocolTransport "
                "object>", ResourceWarning)

    def is_reading(self):
        """Return True unless the application paused reading."""
        return not self._ssl_protocol._app_reading_paused

    def pause_reading(self):
        """Pause the receiving end.

        No data will be passed to the protocol's data_received()
        method until resume_reading() is called.
        """
        self._ssl_protocol._pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Data received will once again be passed to the protocol's
        data_received() method.
        """
        self._ssl_protocol._resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        These two values control when to call the protocol's
        pause_writing() and resume_writing() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_writing() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_writing() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._set_write_buffer_limits(high, low)
        # Re-evaluate the paused/resumed state against the new limits.
        self._ssl_protocol._control_app_writing()

    def get_write_buffer_limits(self):
        """Return the write flow-control limits as ``(low, high)``."""
        return (self._ssl_protocol._outgoing_low_water,
                self._ssl_protocol._outgoing_high_water)

    def get_write_buffer_size(self):
        """Return the current size of the write buffers."""
        return self._ssl_protocol._get_write_buffer_size()

    def set_read_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for read flow control.

        These two values control when to call the upstream transport's
        pause_reading() and resume_reading() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_reading() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_reading() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._set_read_buffer_limits(high, low)
        # Re-evaluate the paused/resumed state against the new limits.
        self._ssl_protocol._control_ssl_reading()

    def get_read_buffer_limits(self):
        """Return the read flow-control limits as ``(low, high)``."""
        return (self._ssl_protocol._incoming_low_water,
                self._ssl_protocol._incoming_high_water)

    def get_read_buffer_size(self):
        """Return the current size of the read buffer."""
        return self._ssl_protocol._get_read_buffer_size()

    def _protocol_paused(self):
        # Required for sendfile fallback pause_writing/resume_writing logic
        return self._ssl_protocol._app_writing_paused

    def write(self, data):
        """Write some data bytes to the transport.

        This does not block; it buffers the data and arranges for it
        to be sent out asynchronously.

        Raises:
            TypeError: if ``data`` is not bytes, bytearray or memoryview.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"data: expecting a bytes-like instance, "
                            f"got {type(data).__name__}")
        if not data:
            return
        self._ssl_protocol._write_appdata((data,))

    def writelines(self, list_of_data):
        """Write a list (or any iterable) of data bytes to the transport.

        The iterable is handed to the SSL protocol as-is; unlike
        write(), the items are not type-checked here.
        """
        self._ssl_protocol._write_appdata(list_of_data)

    def write_eof(self):
        """Close the write end after flushing buffered data.

        This raises :exc:`NotImplementedError` right now.
        """
        raise NotImplementedError

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data will be lost.  No more data will be received.
        The protocol's connection_lost() method will (eventually) be
        called with None as its argument.
        """
        self._closed = True
        if self._ssl_protocol is not None:
            self._ssl_protocol._abort()

    def _force_close(self, exc):
        """Mark the transport closed and abort the SSL protocol with *exc*.

        NOTE(review): this passes ``exc`` to ``SSLProtocol._abort``;
        confirm that method's signature accepts an exception argument.
        """
        self._closed = True
        # A repeated close() sets _ssl_protocol to None; guard the same
        # way abort() does instead of raising AttributeError.
        if self._ssl_protocol is not None:
            self._ssl_protocol._abort(exc)

    def _test__append_write_backlog(self, data):
        # for test only
        self._ssl_protocol._write_backlog.append(data)
        self._ssl_protocol._write_buffer_size += len(data)
class SSLProtocol(protocols.BufferedProtocol):
    """SSL/TLS layer sitting between a lower transport and an app protocol.

    Encrypted bytes from the lower transport are fed into an incoming
    ``ssl.MemoryBIO``; bytes produced by the SSL object are drained from
    an outgoing ``ssl.MemoryBIO`` and written to the lower transport.
    Decrypted data is delivered to the application protocol, which talks
    back through a ``_SSLProtocolTransport``.
    """

    max_size = 256 * 1024   # Buffer size passed to read()

    _handshake_start_time = None
    _handshake_timeout_handle = None
    _shutdown_timeout_handle = None

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True,
                 ssl_handshake_timeout=None,
                 ssl_shutdown_timeout=None):
        """Set up buffers, the SSL object and flow-control state.

        Raises:
            RuntimeError: if the ssl module is unavailable.
            ValueError: if a timeout is given and is not positive, or
                if ``server_side`` is true and no ``sslcontext`` is given.
        """
        if ssl is None:
            raise RuntimeError("stdlib ssl module not available")

        # Buffer (and a view of it) handed out by get_buffer().
        self._ssl_buffer = bytearray(self.max_size)
        self._ssl_buffer_view = memoryview(self._ssl_buffer)

        if ssl_handshake_timeout is None:
            ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
        elif ssl_handshake_timeout <= 0:
            raise ValueError(
                f"ssl_handshake_timeout should be a positive number, "
                f"got {ssl_handshake_timeout}")
        if ssl_shutdown_timeout is None:
            ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
        elif ssl_shutdown_timeout <= 0:
            raise ValueError(
                f"ssl_shutdown_timeout should be a positive number, "
                f"got {ssl_shutdown_timeout}")

        if not sslcontext:
            sslcontext = _create_transport_context(
                server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info are set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._set_app_protocol(app_protocol)
        self._app_transport = None
        self._app_transport_created = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._ssl_handshake_timeout = ssl_handshake_timeout
        self._ssl_shutdown_timeout = ssl_shutdown_timeout
        # SSL and state machine
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        self._state = SSLProtocolState.UNWRAPPED
        self._conn_lost = 0  # Set when connection_lost called
        if call_connection_made:
            self._app_state = AppProtocolState.STATE_INIT
        else:
            # Pretend connection_made() already happened so that
            # _on_handshake_complete() does not call it.
            self._app_state = AppProtocolState.STATE_CON_MADE
        self._sslobj = self._sslcontext.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)

        # Flow Control

        self._ssl_writing_paused = False

        self._app_reading_paused = False

        self._ssl_reading_paused = False
        self._incoming_high_water = 0
        self._incoming_low_water = 0
        self._set_read_buffer_limits()
        self._eof_received = False

        self._app_writing_paused = False
        self._outgoing_high_water = 0
        self._outgoing_low_water = 0
        self._set_write_buffer_limits()
        self._get_app_transport()

    def _set_app_protocol(self, app_protocol):
        """Install *app_protocol* and cache its buffered-API entry points."""
        self._app_protocol = app_protocol
        # Make fast hasattr check first
        if (hasattr(app_protocol, 'get_buffer') and
                isinstance(app_protocol, protocols.BufferedProtocol)):
            self._app_protocol_get_buffer = app_protocol.get_buffer
            self._app_protocol_buffer_updated = app_protocol.buffer_updated
            self._app_protocol_is_buffer = True
        else:
            self._app_protocol_is_buffer = False

    def _wakeup_waiter(self, exc=None):
        """Resolve the waiter future (if any) with *exc* or None, once."""
        if self._waiter is None:
            return
        if not self._waiter.cancelled():
            if exc is not None:
                self._waiter.set_exception(exc)
            else:
                self._waiter.set_result(None)
        self._waiter = None

    def _get_app_transport(self):
        """Return the app-facing transport, creating it on first use.

        Raises:
            RuntimeError: if the transport was created before and has
                since been dropped.
        """
        if self._app_transport is None:
            if self._app_transport_created:
                raise RuntimeError('Creating _SSLProtocolTransport twice')
            self._app_transport = _SSLProtocolTransport(self._loop, self)
            self._app_transport_created = True
        return self._app_transport

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        # Discard everything still queued for sending.
        self._write_backlog.clear()
        self._outgoing.read()
        self._conn_lost += 1

        # Just mark the app transport as closed so that its __del__
        # doesn't complain.
        if self._app_transport is not None:
            self._app_transport._closed = True

        if self._state != SSLProtocolState.DO_HANDSHAKE:
            # Only notify the app if it saw connection_made() and has
            # not been told about the loss yet.
            if (
                self._app_state == AppProtocolState.STATE_CON_MADE or
                self._app_state == AppProtocolState.STATE_EOF
            ):
                self._app_state = AppProtocolState.STATE_CON_LOST
                self._loop.call_soon(self._app_protocol.connection_lost, exc)
        self._set_state(SSLProtocolState.UNWRAPPED)
        self._transport = None
        self._app_transport = None
        self._app_protocol = None
        self._wakeup_waiter(exc)

        if self._shutdown_timeout_handle:
            self._shutdown_timeout_handle.cancel()
            self._shutdown_timeout_handle = None
        if self._handshake_timeout_handle:
            self._handshake_timeout_handle.cancel()
            self._handshake_timeout_handle = None

    def get_buffer(self, n):
        """Return a memoryview the lower transport can receive into.

        A request that is non-positive or larger than ``max_size`` is
        clamped to ``max_size``; the internal buffer is reallocated only
        if it is smaller than the wanted size.
        """
        want = n
        if want <= 0 or want > self.max_size:
            want = self.max_size
        if len(self._ssl_buffer) < want:
            self._ssl_buffer = bytearray(want)
            self._ssl_buffer_view = memoryview(self._ssl_buffer)
        return self._ssl_buffer_view

    def buffer_updated(self, nbytes):
        """Feed *nbytes* received bytes to the SSL object and make progress.

        Which operation is driven depends on the current SSL state.
        """
        self._incoming.write(self._ssl_buffer_view[:nbytes])

        if self._state == SSLProtocolState.DO_HANDSHAKE:
            self._do_handshake()

        elif self._state == SSLProtocolState.WRAPPED:
            self._do_read()

        elif self._state == SSLProtocolState.FLUSHING:
            self._do_flush()

        elif self._state == SSLProtocolState.SHUTDOWN:
            self._do_shutdown()

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.
        """
        self._eof_received = True
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            if self._state == SSLProtocolState.DO_HANDSHAKE:
                # The exception class is passed; it is raised (and thus
                # instantiated) inside _on_handshake_complete().
                self._on_handshake_complete(ConnectionResetError)

            elif self._state == SSLProtocolState.WRAPPED:
                self._set_state(SSLProtocolState.FLUSHING)
                if self._app_reading_paused:
                    # Keep the transport open; flushing continues when
                    # the app resumes reading (see _resume_reading()).
                    return True
                else:
                    self._do_flush()

            elif self._state == SSLProtocolState.FLUSHING:
                self._do_write()
                self._set_state(SSLProtocolState.SHUTDOWN)
                self._do_shutdown()

            elif self._state == SSLProtocolState.SHUTDOWN:
                self._do_shutdown()

        except Exception:
            self._transport.close()
            raise

    def _get_extra_info(self, name, default=None):
        """Look up *name* in the SSL extras, then in the lower transport."""
        if name in self._extra:
            return self._extra[name]
        elif self._transport is not None:
            return self._transport.get_extra_info(name, default)
        else:
            return default

    def _set_state(self, new_state):
        """Switch the SSL state machine to *new_state*.

        Raises:
            RuntimeError: if the transition is not one of the allowed
                ones (any -> UNWRAPPED, or one step forward along
                UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> FLUSHING ->
                SHUTDOWN).
        """
        allowed = False

        if new_state == SSLProtocolState.UNWRAPPED:
            allowed = True

        elif (
            self._state == SSLProtocolState.UNWRAPPED and
            new_state == SSLProtocolState.DO_HANDSHAKE
        ):
            allowed = True

        elif (
            self._state == SSLProtocolState.DO_HANDSHAKE and
            new_state == SSLProtocolState.WRAPPED
        ):
            allowed = True

        elif (
            self._state == SSLProtocolState.WRAPPED and
            new_state == SSLProtocolState.FLUSHING
        ):
            allowed = True

        elif (
            self._state == SSLProtocolState.FLUSHING and
            new_state == SSLProtocolState.SHUTDOWN
        ):
            allowed = True

        if allowed:
            self._state = new_state

        else:
            raise RuntimeError(
                'cannot switch state from {} to {}'.format(
                    self._state, new_state))

    # Handshake flow

    def _start_handshake(self):
        """Enter DO_HANDSHAKE, arm the timeout and take the first step."""
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None

        self._set_state(SSLProtocolState.DO_HANDSHAKE)

        # start handshake timeout count down
        self._handshake_timeout_handle = \
            self._loop.call_later(self._ssl_handshake_timeout,
                                  lambda: self._check_handshake_timeout())

        self._do_handshake()

    def _check_handshake_timeout(self):
        """Timer callback: fail the connection if still handshaking."""
        if self._state == SSLProtocolState.DO_HANDSHAKE:
            msg = (
                f"SSL handshake is taking longer than "
                f"{self._ssl_handshake_timeout} seconds: "
                f"aborting the connection"
            )
            self._fatal_error(ConnectionAbortedError(msg))

    def _do_handshake(self):
        """Advance the handshake by one step."""
        try:
            self._sslobj.do_handshake()
        except SSLAgainErrors:
            # Not finished yet: send whatever the SSL object produced.
            self._process_outgoing()
        except ssl.SSLError as exc:
            self._on_handshake_complete(exc)
        else:
            self._on_handshake_complete(None)

    def _on_handshake_complete(self, handshake_exc):
        """Finish the handshake, successfully or with *handshake_exc*.

        On failure the state returns to UNWRAPPED, the error is reported
        through _fatal_error() and the waiter is woken with it.  On
        success the extra info is filled in, the app protocol receives
        connection_made() (if it has not yet), the waiter is woken and
        reading starts.
        """
        if self._handshake_timeout_handle is not None:
            self._handshake_timeout_handle.cancel()
            self._handshake_timeout_handle = None

        sslobj = self._sslobj
        try:
            if handshake_exc is None:
                self._set_state(SSLProtocolState.WRAPPED)
            else:
                raise handshake_exc

            peercert = sslobj.getpeercert()
        except Exception as exc:
            self._set_state(SSLProtocolState.UNWRAPPED)
            if isinstance(exc, ssl.CertificateError):
                msg = 'SSL handshake failed on verifying the certificate'
            else:
                msg = 'SSL handshake failed'
            self._fatal_error(exc, msg)
            self._wakeup_waiter(exc)
            return

        # _handshake_start_time is None when debug mode was off at the
        # time the handshake started; skip the timing log in that case
        # instead of failing on ``float - None``.
        if (self._loop.get_debug() and
                self._handshake_start_time is not None):
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           ssl_object=sslobj)
        if self._app_state == AppProtocolState.STATE_INIT:
            self._app_state = AppProtocolState.STATE_CON_MADE
            self._app_protocol.connection_made(self._get_app_transport())
        self._wakeup_waiter()
        self._do_read()

    # Shutdown flow

    def _start_shutdown(self):
        """Begin an orderly shutdown (or abort if still handshaking).

        Does nothing when a shutdown is already under way or the
        connection is unwrapped.
        """
        if (
            self._state in (
                SSLProtocolState.FLUSHING,
                SSLProtocolState.SHUTDOWN,
                SSLProtocolState.UNWRAPPED
            )
        ):
            return
        if self._app_transport is not None:
            self._app_transport._closed = True
        if self._state == SSLProtocolState.DO_HANDSHAKE:
            self._abort()
        else:
            self._set_state(SSLProtocolState.FLUSHING)
            self._shutdown_timeout_handle = self._loop.call_later(
                self._ssl_shutdown_timeout,
                lambda: self._check_shutdown_timeout()
            )
            self._do_flush()

    def _check_shutdown_timeout(self):
        """Timer callback: force-close if the shutdown has not finished."""
        if (
            self._state in (
                SSLProtocolState.FLUSHING,
                SSLProtocolState.SHUTDOWN
            )
        ):
            self._transport._force_close(
                exceptions.TimeoutError('SSL shutdown timed out'))

    def _do_flush(self):
        """Drain pending reads, then move on to SHUTDOWN."""
        self._do_read()
        self._set_state(SSLProtocolState.SHUTDOWN)
        self._do_shutdown()

    def _do_shutdown(self):
        """Advance the SSL shutdown by one step."""
        try:
            # Once the peer sent EOF, unwrap() is skipped.
            if not self._eof_received:
                self._sslobj.unwrap()
        except SSLAgainErrors:
            self._process_outgoing()
        except ssl.SSLError as exc:
            self._on_shutdown_complete(exc)
        else:
            self._process_outgoing()
            self._call_eof_received()
            self._on_shutdown_complete(None)

    def _on_shutdown_complete(self, shutdown_exc):
        """Cancel the shutdown timer and close (or fail) the transport."""
        if self._shutdown_timeout_handle is not None:
            self._shutdown_timeout_handle.cancel()
            self._shutdown_timeout_handle = None

        if shutdown_exc:
            self._fatal_error(shutdown_exc)
        else:
            self._loop.call_soon(self._transport.close)

    def _abort(self, exc=None):
        """Drop to UNWRAPPED and tear down the lower transport.

        Args:
            exc: optional exception describing why the connection is
                aborted.  ``_SSLProtocolTransport._force_close()`` passes
                one; previously this method accepted no argument, so
                that call raised TypeError.
        """
        self._set_state(SSLProtocolState.UNWRAPPED)
        if self._transport is not None:
            if exc is None:
                self._transport.abort()
            else:
                # Forward the reason the same way _fatal_error() and
                # _check_shutdown_timeout() do.
                self._transport._force_close(exc)

    # Outgoing flow

    def _write_appdata(self, list_of_data):
        """Queue application data and, when WRAPPED, encrypt and send it.

        Data written while shutting down or unwrapped is dropped; a
        warning is logged once enough such writes have accumulated.
        """
        if (
            self._state in (
                SSLProtocolState.FLUSHING,
                SSLProtocolState.SHUTDOWN,
                SSLProtocolState.UNWRAPPED
            )
        ):
            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
                logger.warning('SSL connection is closed')
            self._conn_lost += 1
            return

        for data in list_of_data:
            self._write_backlog.append(data)
            self._write_buffer_size += len(data)

        try:
            # During the handshake the data simply stays in the backlog.
            if self._state == SSLProtocolState.WRAPPED:
                self._do_write()

        except Exception as ex:
            self._fatal_error(ex, 'Fatal error on SSL protocol')

    def _do_write(self):
        """Push as much of the backlog through the SSL object as it takes."""
        try:
            while self._write_backlog:
                data = self._write_backlog[0]
                count = self._sslobj.write(data)
                data_len = len(data)
                if count < data_len:
                    # Partial write: keep the unsent tail at the front.
                    self._write_backlog[0] = data[count:]
                    self._write_buffer_size -= count
                else:
                    del self._write_backlog[0]
                    self._write_buffer_size -= data_len
        except SSLAgainErrors:
            pass
        self._process_outgoing()

    def _process_outgoing(self):
        """Move encrypted output to the lower transport unless paused."""
        if not self._ssl_writing_paused:
            data = self._outgoing.read()
            if len(data):
                self._transport.write(data)
        self._control_app_writing()

    # Incoming flow

    def _do_read(self):
        """Deliver decrypted data to the app and service pending writes.

        Only acts in the WRAPPED and FLUSHING states.  Any exception is
        routed to _fatal_error().
        """
        if (
            self._state not in (
                SSLProtocolState.WRAPPED,
                SSLProtocolState.FLUSHING,
            )
        ):
            return
        try:
            if not self._app_reading_paused:
                if self._app_protocol_is_buffer:
                    self._do_read__buffered()
                else:
                    self._do_read__copied()
                if self._write_backlog:
                    self._do_write()
                else:
                    self._process_outgoing()
            self._control_ssl_reading()
        except Exception as ex:
            self._fatal_error(ex, 'Fatal error on SSL protocol')

    def _do_read__buffered(self):
        """Read decrypted data straight into the app protocol's buffer."""
        offset = 0
        # Non-zero so that an immediate SSLAgainErrors is not mistaken
        # for a zero-byte read (close_notify) below.
        count = 1

        buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
        wants = len(buf)

        try:
            count = self._sslobj.read(wants, buf)

            if count > 0:
                offset = count
                while offset < wants:
                    count = self._sslobj.read(wants - offset, buf[offset:])
                    if count > 0:
                        offset += count
                    else:
                        break
                else:
                    # The buffer was filled completely; schedule another
                    # read for whatever may remain.
                    self._loop.call_soon(lambda: self._do_read())
        except SSLAgainErrors:
            pass
        if offset > 0:
            self._app_protocol_buffer_updated(offset)
        if not count:
            # close_notify
            self._call_eof_received()
            self._start_shutdown()

    def _do_read__copied(self):
        """Read decrypted data and pass it to data_received() in one call."""
        # Non-empty so that an immediate SSLAgainErrors is not mistaken
        # for an empty read (close_notify) below.
        chunk = b'1'
        # zero / one track whether no chunk or exactly one chunk has
        # been read, so a single chunk is delivered without joining.
        zero = True
        one = False

        try:
            while True:
                chunk = self._sslobj.read(self.max_size)
                if not chunk:
                    break
                if zero:
                    zero = False
                    one = True
                    first = chunk
                elif one:
                    one = False
                    data = [first, chunk]
                else:
                    data.append(chunk)
        except SSLAgainErrors:
            pass
        if one:
            self._app_protocol.data_received(first)
        elif not zero:
            self._app_protocol.data_received(b''.join(data))
        if not chunk:
            # close_notify
            self._call_eof_received()
            self._start_shutdown()

    def _call_eof_received(self):
        """Deliver eof_received() to the app protocol at most once."""
        try:
            if self._app_state == AppProtocolState.STATE_CON_MADE:
                self._app_state = AppProtocolState.STATE_EOF
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as ex:
            self._fatal_error(ex, 'Error calling eof_received()')

    # Flow control for writes from APP socket

    def _control_app_writing(self):
        """Pause or resume the app protocol based on the write buffer size.

        Exceptions raised by the app's pause_writing()/resume_writing()
        are reported to the loop's exception handler, except
        KeyboardInterrupt and SystemExit, which propagate.
        """
        size = self._get_write_buffer_size()
        if size >= self._outgoing_high_water and not self._app_writing_paused:
            self._app_writing_paused = True
            try:
                self._app_protocol.pause_writing()
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as exc:
                self._loop.call_exception_handler({
                    'message': 'protocol.pause_writing() failed',
                    'exception': exc,
                    'transport': self._app_transport,
                    'protocol': self,
                })
        elif size <= self._outgoing_low_water and self._app_writing_paused:
            self._app_writing_paused = False
            try:
                self._app_protocol.resume_writing()
            except (KeyboardInterrupt, SystemExit):
                raise
            except BaseException as exc:
                self._loop.call_exception_handler({
                    'message': 'protocol.resume_writing() failed',
                    'exception': exc,
                    'transport': self._app_transport,
                    'protocol': self,
                })

    def _get_write_buffer_size(self):
        """Return pending encrypted output plus the plaintext backlog size."""
        return self._outgoing.pending + self._write_buffer_size

    def _set_write_buffer_limits(self, high=None, low=None):
        """Store the write water marks, applying defaults where missing."""
        high, low = add_flowcontrol_defaults(
            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
        self._outgoing_high_water = high
        self._outgoing_low_water = low

    # Flow control for reads to APP socket

    def _pause_reading(self):
        """Stop delivering decrypted data to the app protocol."""
        self._app_reading_paused = True

    def _resume_reading(self):
        """Resume delivery to the app and schedule the pending step."""
        if self._app_reading_paused:
            self._app_reading_paused = False

            def resume():
                # The state is inspected when the callback runs, not
                # when it is scheduled.
                if self._state == SSLProtocolState.WRAPPED:
                    self._do_read()
                elif self._state == SSLProtocolState.FLUSHING:
                    self._do_flush()
                elif self._state == SSLProtocolState.SHUTDOWN:
                    self._do_shutdown()
            self._loop.call_soon(resume)

    # Flow control for reads from SSL socket

    def _control_ssl_reading(self):
        """Pause or resume the lower transport based on the read buffer."""
        size = self._get_read_buffer_size()
        if size >= self._incoming_high_water and not self._ssl_reading_paused:
            self._ssl_reading_paused = True
            self._transport.pause_reading()
        elif size <= self._incoming_low_water and self._ssl_reading_paused:
            self._ssl_reading_paused = False
            self._transport.resume_reading()

    def _set_read_buffer_limits(self, high=None, low=None):
        """Store the read water marks, applying defaults where missing."""
        high, low = add_flowcontrol_defaults(
            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
        self._incoming_high_water = high
        self._incoming_low_water = low

    def _get_read_buffer_size(self):
        """Return the number of bytes pending in the incoming BIO."""
        return self._incoming.pending

    # Flow control for writes to SSL socket

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        assert not self._ssl_writing_paused
        self._ssl_writing_paused = True

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        assert self._ssl_writing_paused
        self._ssl_writing_paused = False
        self._process_outgoing()

    def _fatal_error(self, exc, message='Fatal error on transport'):
        """Force-close the lower transport and report *exc*.

        OSError is only logged (in debug mode); CancelledError is not
        reported; everything else goes to the loop's exception handler.
        """
        if self._transport:
            self._transport._force_close(exc)

        if isinstance(exc, OSError):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        elif not isinstance(exc, exceptions.CancelledError):
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })