File size: 3,690 Bytes
4cef980 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import socket
import ssl
import sys
import typing
from .._exceptions import (
ConnectError,
ConnectTimeout,
ExceptionMapping,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import NetworkBackend, NetworkStream
class SyncStream(NetworkStream):
    """A blocking network stream backed by a standard-library socket.

    Socket errors raised by the I/O methods are translated into this
    package's exception types through ``map_exceptions``.
    """

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """Receive up to ``max_bytes`` bytes from the socket.

        Args:
            max_bytes: Maximum number of bytes to return.
            timeout: Socket timeout in seconds for this call; ``None`` puts
                the socket in blocking mode with no timeout.

        Returns:
            The bytes returned by ``socket.recv``.

        Raises:
            ReadTimeout: If the receive times out.
            ReadError: On any other ``OSError``.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """Send the whole of ``buffer`` over the socket.

        ``timeout`` bounds each individual ``send`` call, not the write as a
        whole (``socket.sendall`` would instead treat it as a total budget).

        Args:
            buffer: Data to send; an empty buffer is a no-op.
            timeout: Per-``send`` socket timeout in seconds, or ``None``.

        Raises:
            WriteTimeout: If a ``send`` times out.
            WriteError: On any other ``OSError``.
        """
        if not buffer:
            return

        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            # Advance through a memoryview instead of re-slicing the bytes
            # object: ``buffer[n:]`` copies the unsent tail after every
            # partial send, which is quadratic for large payloads.
            view = memoryview(buffer)
            while view:
                sent = self._sock.send(view)
                view = view[sent:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> NetworkStream:
        """Wrap this connection in TLS and return a new stream for it.

        ``ssl_context.wrap_socket`` is called with its default arguments, so
        the handshake happens during this call under ``timeout``.

        Args:
            ssl_context: Context used to wrap the socket.
            server_hostname: Hostname passed to ``wrap_socket`` (used for SNI
                and certificate hostname checking).
            timeout: Socket timeout in seconds for the wrap/handshake.

        Returns:
            A new ``SyncStream`` around the wrapped socket.

        Raises:
            ConnectTimeout: If the handshake times out.
            ConnectError: On any other ``OSError``.
        """
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
            except Exception:  # pragma: nocover
                # Do not leave the plain socket open if wrapping fails.
                self.close()
                raise
        return SyncStream(sock)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return transport details selected by ``info``.

        Recognised keys:
            ``"ssl_object"``: the socket's internal SSL object, only when the
                socket is an ``ssl.SSLSocket``.
            ``"client_addr"``: the local address (``getsockname``).
            ``"server_addr"``: the peer address (``getpeername``).
            ``"socket"``: the underlying socket itself.
            ``"is_readable"``: result of ``is_socket_readable`` on the socket.

        Any other key (including ``"ssl_object"`` on a non-TLS socket)
        returns ``None``.
        """
        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
            return self._sock._sslobj  # type: ignore
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None
class SyncBackend(NetworkBackend):
    """Network backend that opens blocking standard-library sockets."""

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
    ) -> NetworkStream:
        """Open a TCP connection to ``(host, port)``.

        Args:
            host: Remote host name or address.
            port: Remote port number.
            timeout: Timeout in seconds passed to ``socket.create_connection``.
            local_address: Optional local address to bind before connecting;
                port ``0`` is used so the OS chooses the local port.

        Returns:
            A ``SyncStream`` wrapping the connected socket.

        Raises:
            ConnectTimeout: If connecting times out.
            ConnectError: On any other ``OSError``.
        """
        address = (host, port)
        source_address = None if local_address is None else (local_address, 0)
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.create_connection(
                address, timeout, source_address=source_address
            )
        return SyncStream(sock)

    def connect_unix_socket(
        self, path: str, timeout: typing.Optional[float] = None
    ) -> NetworkStream:  # pragma: nocover
        """Connect to the UNIX domain socket at ``path``.

        Args:
            path: Filesystem path of the UNIX socket.
            timeout: Socket timeout in seconds set before connecting.

        Returns:
            A ``SyncStream`` wrapping the connected socket.

        Raises:
            RuntimeError: When running on Windows.
            ConnectTimeout: If connecting times out.
            ConnectError: On any other ``OSError``.
        """
        if sys.platform == "win32":
            raise RuntimeError(
                "Attempted to connect to a UNIX socket on a Windows system."
            )

        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                sock.settimeout(timeout)
                sock.connect(path)
            except BaseException:
                # Close the socket we just created so a failed connect does
                # not leak a file descriptor; the error is re-raised unchanged
                # and still goes through the exception mapping above.
                sock.close()
                raise
        return SyncStream(sock)
|