# aiohttp/connector.py
# (vendored copy found at DeepSolanaCoder/DeepSeek-Coder-main/finetune/venv/
#  lib/python3.12/site-packages/aiohttp/connector.py)
| import asyncio | |
| import functools | |
| import random | |
| import socket | |
| import sys | |
| import traceback | |
| import warnings | |
| from collections import OrderedDict, defaultdict, deque | |
| from contextlib import suppress | |
| from http import HTTPStatus | |
| from itertools import chain, cycle, islice | |
| from time import monotonic | |
| from types import TracebackType | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Awaitable, | |
| Callable, | |
| DefaultDict, | |
| Deque, | |
| Dict, | |
| Iterator, | |
| List, | |
| Literal, | |
| Optional, | |
| Sequence, | |
| Set, | |
| Tuple, | |
| Type, | |
| Union, | |
| cast, | |
| ) | |
| import aiohappyeyeballs | |
| from . import hdrs, helpers | |
| from .abc import AbstractResolver, ResolveResult | |
| from .client_exceptions import ( | |
| ClientConnectionError, | |
| ClientConnectorCertificateError, | |
| ClientConnectorDNSError, | |
| ClientConnectorError, | |
| ClientConnectorSSLError, | |
| ClientHttpProxyError, | |
| ClientProxyConnectionError, | |
| ServerFingerprintMismatch, | |
| UnixClientConnectorError, | |
| cert_errors, | |
| ssl_errors, | |
| ) | |
| from .client_proto import ResponseHandler | |
| from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params | |
| from .helpers import ( | |
| ceil_timeout, | |
| is_ip_address, | |
| noop, | |
| sentinel, | |
| set_exception, | |
| set_result, | |
| ) | |
| from .resolver import DefaultResolver | |
# Import ssl unconditionally for type checkers; at runtime the import may
# legitimately fail (e.g. Python built without OpenSSL), in which case the
# names degrade to placeholders and the TCP connector runs without TLS.
if TYPE_CHECKING:
    import ssl

    SSLContext = ssl.SSLContext
else:
    try:
        import ssl

        SSLContext = ssl.SSLContext
    except ImportError:  # pragma: no cover
        ssl = None  # type: ignore[assignment]
        SSLContext = object  # type: ignore[misc,assignment]

# URL-scheme sets used by connector implementations to validate requests.
EMPTY_SCHEMA_SET = frozenset({""})
HTTP_SCHEMA_SET = frozenset({"http", "https"})
WS_SCHEMA_SET = frozenset({"ws", "wss"})

HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET

# True when this interpreter still needs the "cleanup closed" workaround.
# Cleanup closed is no longer needed after https://github.com/python/cpython/pull/118960
# which first appeared in Python 3.12.7 and 3.13.1
NEEDS_CLEANUP_CLOSED = (3, 13, 0) <= sys.version_info < (
    3,
    13,
    1,
) or sys.version_info < (3, 12, 7)

__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")

# Imported only for annotations to avoid import cycles at runtime.
if TYPE_CHECKING:
    from .client import ClientTimeout
    from .client_reqrep import ConnectionKey
    from .tracing import Trace
class _DeprecationWaiter:
    """Wrap an awaitable so a warning can be emitted if it is never awaited.

    ``BaseConnector.close()`` returns one of these (wrapping a no-op
    coroutine) so legacy callers that forget ``await`` get a
    ``DeprecationWarning`` from the finalizer instead of silent misbehavior.
    """

    __slots__ = ("_awaitable", "_awaited")

    def __init__(self, awaitable: Awaitable[Any]) -> None:
        self._awaitable = awaitable
        self._awaited = False

    def __await__(self) -> Any:
        # Record that the caller did await us before delegating.
        self._awaited = True
        return self._awaitable.__await__()

    def __del__(self) -> None:
        # Fires at GC time; only warn when the wrapper was never awaited.
        if not self._awaited:
            warnings.warn(
                "Connector.close() is a coroutine, "
                "please use await connector.close()",
                DeprecationWarning,
            )
class Connection:
    """A connection borrowed from a :class:`BaseConnector` pool.

    Wraps a :class:`ResponseHandler` protocol together with the pool key it
    was acquired under, and hands it back to (or force-closes it for) the
    owning connector on :meth:`release` / :meth:`close`.

    NOTE(review): the transcribed source had ``loop``/``transport``/
    ``protocol``/``closed`` as plain methods, but the deprecation message in
    ``loop`` ("connector.loop property is deprecated") and aiohttp's public
    API show these are read-only properties; the ``@property`` decorators
    are restored here.
    """

    _source_traceback = None

    def __init__(
        self,
        connector: "BaseConnector",
        key: "ConnectionKey",
        protocol: ResponseHandler,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._key = key
        self._connector = connector
        self._loop = loop
        # Becomes None once the connection has been released or closed.
        self._protocol: Optional[ResponseHandler] = protocol
        self._callbacks: List[Callable[[], None]] = []

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

    def __repr__(self) -> str:
        return f"Connection<{self._key}>"

    def __del__(self, _warnings: Any = warnings) -> None:
        # A live protocol at GC time means the user leaked the connection:
        # warn, force-close it back into the connector, and report through
        # the loop's exception handler.
        if self._protocol is not None:
            kwargs = {"source": self}
            _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs)
            if self._loop.is_closed():
                return

            self._connector._release(self._key, self._protocol, should_close=True)

            context = {"client_connection": self, "message": "Unclosed connection"}
            if self._source_traceback is not None:
                context["source_traceback"] = self._source_traceback
            self._loop.call_exception_handler(context)

    def __bool__(self) -> Literal[True]:
        """Force subclasses to not be falsy, to make checks simpler."""
        return True

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """The event loop (deprecated accessor, kept for compatibility)."""
        warnings.warn(
            "connector.loop property is deprecated", DeprecationWarning, stacklevel=2
        )
        return self._loop

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """Underlying transport, or None once released/closed."""
        if self._protocol is None:
            return None
        return self._protocol.transport

    @property
    def protocol(self) -> Optional[ResponseHandler]:
        """Underlying protocol, or None once released/closed."""
        return self._protocol

    def add_callback(self, callback: Callable[[], None]) -> None:
        """Register *callback* to run once when the connection is released."""
        if callback is not None:
            self._callbacks.append(callback)

    def _notify_release(self) -> None:
        # Swap the list out first so callbacks registered during iteration
        # cannot be invoked in this round; exceptions are suppressed.
        callbacks, self._callbacks = self._callbacks[:], []

        for cb in callbacks:
            with suppress(Exception):
                cb()

    def close(self) -> None:
        """Return the connection to the connector and force-close it."""
        self._notify_release()

        if self._protocol is not None:
            self._connector._release(self._key, self._protocol, should_close=True)
            self._protocol = None

    def release(self) -> None:
        """Return the connection to the connector pool for possible reuse."""
        self._notify_release()

        if self._protocol is not None:
            self._connector._release(self._key, self._protocol)
            self._protocol = None

    @property
    def closed(self) -> bool:
        """True if the connection was released or its transport dropped."""
        return self._protocol is None or not self._protocol.is_connected()
class _TransportPlaceholder:
    """placeholder for BaseConnector.connect function"""

    # Reserves a slot in the connector's "acquired" sets while the real
    # connection is still being established; only needs a no-op close().
    __slots__ = ()

    def close(self) -> None:
        """Close the placeholder transport."""
class BaseConnector:
    """Base connector class.

    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    enable_cleanup_closed - Enables clean-up closed ssl transports.
                            Disabled by default.
    timeout_ceil_threshold - Trigger ceiling of timeout values when
                             it's above timeout_ceil_threshold.
    loop - Optional event loop.

    NOTE(review): the transcribed source had ``force_close``/``limit``/
    ``limit_per_host``/``closed`` as plain methods, yet ``closed`` documents
    itself as "A readonly property." — the lost ``@property`` decorators are
    restored here, matching aiohttp's public API.
    """

    _closed = True  # prevent AttributeError in __del__ if ctor was failed
    _source_traceback = None

    # abort transport after 2 seconds (cleanup broken connections)
    _cleanup_closed_period = 2.0

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET

    def __init__(
        self,
        *,
        keepalive_timeout: Union[object, None, float] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        timeout_ceil_threshold: float = 5,
    ) -> None:
        # keepalive_timeout and force_close are mutually exclusive; the
        # default keep-alive is 15s when keep-alive is enabled.
        if force_close:
            if keepalive_timeout is not None and keepalive_timeout is not sentinel:
                raise ValueError(
                    "keepalive_timeout cannot be set if force_close is True"
                )
        else:
            if keepalive_timeout is sentinel:
                keepalive_timeout = 15.0

        loop = loop or asyncio.get_running_loop()
        self._timeout_ceil_threshold = timeout_ceil_threshold

        self._closed = False
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Connection pool of reusable connections.
        # We use a deque to store connections because it has O(1) popleft()
        # and O(1) append() operations to implement a FIFO queue.
        self._conns: DefaultDict[
            ConnectionKey, Deque[Tuple[ResponseHandler, float]]
        ] = defaultdict(deque)
        self._limit = limit
        self._limit_per_host = limit_per_host
        self._acquired: Set[ResponseHandler] = set()
        self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = (
            defaultdict(set)
        )
        self._keepalive_timeout = cast(float, keepalive_timeout)
        self._force_close = force_close

        # {host_key: FIFO list of waiters}
        # The FIFO is implemented with an OrderedDict with None keys because
        # python does not have an ordered set.
        self._waiters: DefaultDict[
            ConnectionKey, OrderedDict[asyncio.Future[None], None]
        ] = defaultdict(OrderedDict)

        self._loop = loop
        self._factory = functools.partial(ResponseHandler, loop=loop)

        # start keep-alive connection cleanup task
        self._cleanup_handle: Optional[asyncio.TimerHandle] = None

        # start cleanup closed transports task
        self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None

        if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED:
            warnings.warn(
                "enable_cleanup_closed ignored because "
                "https://github.com/python/cpython/pull/118960 is fixed "
                f"in Python version {sys.version_info}",
                DeprecationWarning,
                stacklevel=2,
            )
            enable_cleanup_closed = False

        self._cleanup_closed_disabled = not enable_cleanup_closed
        self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
        self._cleanup_closed()

    def __del__(self, _warnings: Any = warnings) -> None:
        # Warn about a connector that still holds pooled connections at GC
        # time; nothing to do if it was closed or the pool is empty.
        if self._closed:
            return
        if not self._conns:
            return

        conns = [repr(c) for c in self._conns.values()]

        self._close()

        kwargs = {"source": self}
        _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs)
        context = {
            "connector": self,
            "connections": conns,
            "message": "Unclosed connector",
        }
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    def __enter__(self) -> "BaseConnector":
        warnings.warn(
            '"with Connector():" is deprecated, '
            'use "async with Connector():" instead',
            DeprecationWarning,
        )
        return self

    def __exit__(self, *exc: Any) -> None:
        self._close()

    async def __aenter__(self) -> "BaseConnector":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        exc_traceback: Optional[TracebackType] = None,
    ) -> None:
        await self.close()

    @property
    def force_close(self) -> bool:
        """Ultimately close connection on releasing if True."""
        return self._force_close

    @property
    def limit(self) -> int:
        """The total number for simultaneous connections.

        If limit is 0 the connector has no limit.
        The default limit size is 100.
        """
        return self._limit

    @property
    def limit_per_host(self) -> int:
        """The limit for simultaneous connections to the same endpoint.

        Endpoints are the same if they are have equal
        (host, port, is_ssl) triple.
        """
        return self._limit_per_host

    def _cleanup(self) -> None:
        """Cleanup unused transports."""
        if self._cleanup_handle:
            self._cleanup_handle.cancel()
            # _cleanup_handle should be unset, otherwise _release() will not
            # recreate it ever!
            self._cleanup_handle = None

        now = monotonic()
        timeout = self._keepalive_timeout

        if self._conns:
            connections = defaultdict(deque)
            deadline = now - timeout
            for key, conns in self._conns.items():
                alive: Deque[Tuple[ResponseHandler, float]] = deque()
                for proto, use_time in conns:
                    # Keep connections that are still up and were used
                    # within the keep-alive window; close the rest.
                    if proto.is_connected() and use_time - deadline >= 0:
                        alive.append((proto, use_time))
                        continue
                    transport = proto.transport
                    proto.close()
                    if not self._cleanup_closed_disabled and key.is_ssl:
                        self._cleanup_closed_transports.append(transport)

                if alive:
                    connections[key] = alive

            self._conns = connections

        if self._conns:
            # Re-arm only while there is something left to expire.
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def _cleanup_closed(self) -> None:
        """Double confirmation for transport close.

        Some broken ssl servers may leave socket open without proper close.
        """
        if self._cleanup_closed_handle:
            self._cleanup_closed_handle.cancel()

        for transport in self._cleanup_closed_transports:
            if transport is not None:
                transport.abort()

        self._cleanup_closed_transports = []

        if not self._cleanup_closed_disabled:
            self._cleanup_closed_handle = helpers.weakref_handle(
                self,
                "_cleanup_closed",
                self._cleanup_closed_period,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def close(self) -> Awaitable[None]:
        """Close all opened transports."""
        self._close()
        # Returned wrapper warns if the (legacy, sync-style) caller forgets
        # to await it.
        return _DeprecationWaiter(noop())

    def _close(self) -> None:
        if self._closed:
            return

        self._closed = True

        try:
            if self._loop.is_closed():
                return

            # cancel cleanup task
            if self._cleanup_handle:
                self._cleanup_handle.cancel()

            # cancel cleanup close task
            if self._cleanup_closed_handle:
                self._cleanup_closed_handle.cancel()

            for data in self._conns.values():
                for proto, t0 in data:
                    proto.close()

            for proto in self._acquired:
                proto.close()

            for transport in self._cleanup_closed_transports:
                if transport is not None:
                    transport.abort()

        finally:
            # Always drop bookkeeping and wake up (cancel) anyone still
            # waiting for a connection slot.
            self._conns.clear()
            self._acquired.clear()
            for keyed_waiters in self._waiters.values():
                for keyed_waiter in keyed_waiters:
                    keyed_waiter.cancel()
            self._waiters.clear()
            self._cleanup_handle = None
            self._cleanup_closed_transports.clear()
            self._cleanup_closed_handle = None

    @property
    def closed(self) -> bool:
        """Is connector closed.

        A readonly property.
        """
        return self._closed

    def _available_connections(self, key: "ConnectionKey") -> int:
        """
        Return number of available connections.

        The limit, limit_per_host and the connection key are taken into account.

        If it returns less than 1 means that there are no connections
        available.
        """
        # check total available connections
        # If there are no limits, this will always return 1
        total_remain = 1

        if self._limit and (total_remain := self._limit - len(self._acquired)) <= 0:
            return total_remain

        # check limit per host
        if host_remain := self._limit_per_host:
            if acquired := self._acquired_per_host.get(key):
                host_remain -= len(acquired)

            if total_remain > host_remain:
                return host_remain

        return total_remain

    async def connect(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Connection:
        """Get from pool or create new connection."""
        key = req.connection_key

        if (conn := await self._get(key, traces)) is not None:
            # If we do not have to wait and we can get a connection from the pool
            # we can avoid the timeout ceil logic and directly return the connection
            return conn

        async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
            if self._available_connections(key) <= 0:
                await self._wait_for_available_connection(key, traces)
                if (conn := await self._get(key, traces)) is not None:
                    return conn

            # Reserve the slot with a placeholder while the real connection
            # is established.
            placeholder = cast(ResponseHandler, _TransportPlaceholder())
            self._acquired.add(placeholder)
            if self._limit_per_host:
                self._acquired_per_host[key].add(placeholder)

            try:
                # Traces are done inside the try block to ensure that the
                # placeholder is still cleaned up if an exception is raised.
                if traces:
                    for trace in traces:
                        await trace.send_connection_create_start()
                proto = await self._create_connection(req, traces, timeout)
                if traces:
                    for trace in traces:
                        await trace.send_connection_create_end()
            except BaseException:
                self._release_acquired(key, placeholder)
                raise
            else:
                if self._closed:
                    proto.close()
                    raise ClientConnectionError("Connector is closed.")

        # The connection was successfully created, drop the placeholder
        # and add the real connection to the acquired set. There should
        # be no awaits after the proto is added to the acquired set
        # to ensure that the connection is not left in the acquired set
        # on cancellation.
        self._acquired.remove(placeholder)
        self._acquired.add(proto)
        if self._limit_per_host:
            acquired_per_host = self._acquired_per_host[key]
            acquired_per_host.remove(placeholder)
            acquired_per_host.add(proto)
        return Connection(self, key, proto, self._loop)

    async def _wait_for_available_connection(
        self, key: "ConnectionKey", traces: List["Trace"]
    ) -> None:
        """Wait for an available connection slot."""
        # We loop here because there is a race between
        # the connection limit check and the connection
        # being acquired.  If the connection is acquired
        # between the check and the await statement, we
        # need to loop again to check if the connection
        # slot is still available.
        attempts = 0
        while True:
            fut: asyncio.Future[None] = self._loop.create_future()
            keyed_waiters = self._waiters[key]
            keyed_waiters[fut] = None
            if attempts:
                # If we have waited before, we need to move the waiter
                # to the front of the queue as otherwise we might get
                # starved and hit the timeout.
                keyed_waiters.move_to_end(fut, last=False)

            try:
                # Traces happen in the try block to ensure that the
                # waiter is still cleaned up if an exception is raised.
                if traces:
                    for trace in traces:
                        await trace.send_connection_queued_start()
                await fut
                if traces:
                    for trace in traces:
                        await trace.send_connection_queued_end()
            finally:
                # pop the waiter from the queue if its still
                # there and not already removed by _release_waiter
                keyed_waiters.pop(fut, None)
                if not self._waiters.get(key, True):
                    del self._waiters[key]

            if self._available_connections(key) > 0:
                break
            attempts += 1

    async def _get(
        self, key: "ConnectionKey", traces: List["Trace"]
    ) -> Optional[Connection]:
        """Get next reusable connection for the key or None.

        The connection will be marked as acquired.
        """
        if (conns := self._conns.get(key)) is None:
            return None

        t1 = monotonic()
        while conns:
            proto, t0 = conns.popleft()
            # We will we reuse the connection if its connected and
            # the keepalive timeout has not been exceeded
            if proto.is_connected() and t1 - t0 <= self._keepalive_timeout:
                if not conns:
                    # The very last connection was reclaimed: drop the key
                    del self._conns[key]
                self._acquired.add(proto)
                if self._limit_per_host:
                    self._acquired_per_host[key].add(proto)
                if traces:
                    for trace in traces:
                        try:
                            await trace.send_connection_reuseconn()
                        except BaseException:
                            self._release_acquired(key, proto)
                            raise
                return Connection(self, key, proto, self._loop)

            # Connection cannot be reused, close it
            transport = proto.transport
            proto.close()
            # only for SSL transports
            if not self._cleanup_closed_disabled and key.is_ssl:
                self._cleanup_closed_transports.append(transport)

        # No more connections: drop the key
        del self._conns[key]
        return None

    def _release_waiter(self) -> None:
        """
        Iterates over all waiters until one to be released is found.

        The one to be released is not finished and
        belongs to a host that has available connections.
        """
        if not self._waiters:
            return

        # Having the dict keys ordered this avoids to iterate
        # at the same order at each call.
        queues = list(self._waiters)
        random.shuffle(queues)

        for key in queues:
            if self._available_connections(key) < 1:
                continue

            waiters = self._waiters[key]
            while waiters:
                waiter, _ = waiters.popitem(last=False)
                if not waiter.done():
                    waiter.set_result(None)
                    return

    def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
        """Release acquired connection."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._acquired.discard(proto)
        if self._limit_per_host and (conns := self._acquired_per_host.get(key)):
            conns.discard(proto)
            if not conns:
                del self._acquired_per_host[key]
        # A slot was freed: wake one compatible waiter, if any.
        self._release_waiter()

    def _release(
        self,
        key: "ConnectionKey",
        protocol: ResponseHandler,
        *,
        should_close: bool = False,
    ) -> None:
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._release_acquired(key, protocol)

        if self._force_close or should_close or protocol.should_close:
            transport = protocol.transport
            protocol.close()

            if key.is_ssl and not self._cleanup_closed_disabled:
                self._cleanup_closed_transports.append(transport)
            return

        # Return the connection to the pool, timestamped for keep-alive
        # expiry, and make sure the cleanup timer is running.
        self._conns[key].append((protocol, monotonic()))

        if self._cleanup_handle is None:
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                self._keepalive_timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        raise NotImplementedError()
class _DNSCacheTable:
    """TTL-bounded cache of DNS results with round-robin address rotation."""

    def __init__(self, ttl: Optional[float] = None) -> None:
        # key -> (cycle over the resolved addresses, number of addresses)
        self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
        # key -> monotonic() insertion time; only populated when a TTL is set
        self._timestamps: Dict[Tuple[str, int], float] = {}
        self._ttl = ttl

    def __contains__(self, host: object) -> bool:
        return host in self._addrs_rr

    def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
        # Store as a cycle so successive lookups rotate through the addresses.
        self._addrs_rr[key] = (cycle(addrs), len(addrs))

        if self._ttl is not None:
            self._timestamps[key] = monotonic()

    def remove(self, key: Tuple[str, int]) -> None:
        self._addrs_rr.pop(key, None)

        if self._ttl is not None:
            self._timestamps.pop(key, None)

    def clear(self) -> None:
        self._addrs_rr.clear()
        self._timestamps.clear()

    def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
        # Take one full rotation of the addresses starting from the cycle's
        # current position.  Raises KeyError if the key was never added.
        loop, length = self._addrs_rr[key]
        addrs = list(islice(loop, length))
        # Consume one more element to shift internal state of `cycle`
        next(loop)
        return addrs

    def expired(self, key: Tuple[str, int]) -> bool:
        # A cache with no TTL never expires.
        if self._ttl is None:
            return False

        return self._timestamps[key] + self._ttl < monotonic()
def _make_ssl_context(verified: bool) -> SSLContext:
    """Create SSL context.

    This method is not async-friendly and should be called from a thread
    because it will load certificates from disk and do other blocking I/O.
    """
    if ssl is None:
        # No ssl support
        return None

    if verified:
        sslcontext = ssl.create_default_context()
    else:
        # Unverified context: TLS encryption without certificate or
        # hostname checks (verify_mode=CERT_NONE).
        sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext.options |= ssl.OP_NO_SSLv2
        sslcontext.options |= ssl.OP_NO_SSLv3
        sslcontext.check_hostname = False
        sslcontext.verify_mode = ssl.CERT_NONE
        sslcontext.options |= ssl.OP_NO_COMPRESSION
        sslcontext.set_default_verify_paths()
    sslcontext.set_alpn_protocols(("http/1.1",))
    return sslcontext


# The default SSLContext objects are created at import time
# since they do blocking I/O to load certificates from disk,
# and imports should always be done before the event loop starts
# or in a thread.
_SSL_CONTEXT_VERIFIED = _make_ssl_context(True)
_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False)
| class TCPConnector(BaseConnector): | |
| """TCP connector. | |
| verify_ssl - Set to True to check ssl certifications. | |
| fingerprint - Pass the binary sha256 | |
| digest of the expected certificate in DER format to verify | |
| that the certificate the server presents matches. See also | |
| https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning | |
| resolver - Enable DNS lookups and use this | |
| resolver | |
| use_dns_cache - Use memory cache for DNS lookups. | |
| ttl_dns_cache - Max seconds having cached a DNS entry, None forever. | |
| family - socket address family | |
| local_addr - local tuple of (host, port) to bind socket to | |
| keepalive_timeout - (optional) Keep-alive timeout. | |
| force_close - Set to True to force close and do reconnect | |
| after each request (and between redirects). | |
| limit - The total number of simultaneous connections. | |
| limit_per_host - Number of simultaneous connections to one host. | |
| enable_cleanup_closed - Enables clean-up closed ssl transports. | |
| Disabled by default. | |
| happy_eyeballs_delay - This is the “Connection Attempt Delay” | |
| as defined in RFC 8305. To disable | |
| the happy eyeballs algorithm, set to None. | |
| interleave - “First Address Family Count” as defined in RFC 8305 | |
| loop - Optional event loop. | |
| """ | |
| allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) | |
    def __init__(
        self,
        *,
        verify_ssl: bool = True,
        fingerprint: Optional[bytes] = None,
        use_dns_cache: bool = True,
        ttl_dns_cache: Optional[int] = 10,
        family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
        ssl_context: Optional[SSLContext] = None,
        ssl: Union[bool, Fingerprint, SSLContext] = True,
        local_addr: Optional[Tuple[str, int]] = None,
        resolver: Optional[AbstractResolver] = None,
        keepalive_timeout: Union[None, float, object] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        timeout_ceil_threshold: float = 5,
        happy_eyeballs_delay: Optional[float] = 0.25,
        interleave: Optional[int] = None,
    ):
        # See the class docstring for the meaning of each argument; the
        # pooling-related ones are forwarded verbatim to BaseConnector.
        super().__init__(
            keepalive_timeout=keepalive_timeout,
            force_close=force_close,
            limit=limit,
            limit_per_host=limit_per_host,
            enable_cleanup_closed=enable_cleanup_closed,
            loop=loop,
            timeout_ceil_threshold=timeout_ceil_threshold,
        )

        # Collapse the legacy ssl/verify_ssl/ssl_context/fingerprint
        # arguments into a single value.
        self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

        if resolver is None:
            resolver = DefaultResolver(loop=self._loop)
        self._resolver = resolver

        self._use_dns_cache = use_dns_cache
        self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
        # (host, port) -> waiter futures for in-flight DNS lookups; used to
        # throttle concurrent resolution of the same endpoint.
        self._throttle_dns_futures: Dict[
            Tuple[str, int], Set["asyncio.Future[None]"]
        ] = {}
        self._family = family
        self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
        self._happy_eyeballs_delay = happy_eyeballs_delay
        self._interleave = interleave
        # Strong references to shielded resolver tasks (see _resolve_host).
        self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
| def close(self) -> Awaitable[None]: | |
| """Close all ongoing DNS calls.""" | |
| for fut in chain.from_iterable(self._throttle_dns_futures.values()): | |
| fut.cancel() | |
| for t in self._resolve_host_tasks: | |
| t.cancel() | |
| return super().close() | |
| def family(self) -> int: | |
| """Socket family like AF_INET.""" | |
| return self._family | |
| def use_dns_cache(self) -> bool: | |
| """True if local DNS caching is enabled.""" | |
| return self._use_dns_cache | |
| def clear_dns_cache( | |
| self, host: Optional[str] = None, port: Optional[int] = None | |
| ) -> None: | |
| """Remove specified host/port or clear all dns local cache.""" | |
| if host is not None and port is not None: | |
| self._cached_hosts.remove((host, port)) | |
| elif host is not None or port is not None: | |
| raise ValueError("either both host and port or none of them are allowed") | |
| else: | |
| self._cached_hosts.clear() | |
    async def _resolve_host(
        self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None
    ) -> List[ResolveResult]:
        """Resolve host and return list of addresses."""
        if is_ip_address(host):
            # Literal IP address: no DNS lookup needed, synthesize one entry.
            return [
                {
                    "hostname": host,
                    "host": host,
                    "port": port,
                    "family": self._family,
                    "proto": 0,
                    "flags": 0,
                }
            ]

        if not self._use_dns_cache:
            # Caching disabled: resolve directly, emitting trace events.
            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_start(host)

            res = await self._resolver.resolve(host, port, family=self._family)

            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_end(host)

            return res

        key = (host, port)
        if key in self._cached_hosts and not self._cached_hosts.expired(key):
            # get result early, before any await (#4014)
            result = self._cached_hosts.next_addrs(key)

            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            return result

        futures: Set["asyncio.Future[None]"]
        #
        # If multiple connectors are resolving the same host, we wait
        # for the first one to resolve and then use the result for all of them.
        # We use a throttle to ensure that we only resolve the host once
        # and then use the result for all the waiters.
        #
        if key in self._throttle_dns_futures:
            # get futures early, before any await (#4014)
            futures = self._throttle_dns_futures[key]
            future: asyncio.Future[None] = self._loop.create_future()
            futures.add(future)
            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            try:
                # Wait for the in-flight lookup owned by another caller.
                await future
            finally:
                futures.discard(future)
            return self._cached_hosts.next_addrs(key)

        # update dict early, before any await (#4014)
        self._throttle_dns_futures[key] = futures = set()
        # In this case we need to create a task to ensure that we can shield
        # the task from cancellation as cancelling this lookup should not cancel
        # the underlying lookup or else the cancel event will get broadcast to
        # all the waiters across all connections.
        #
        coro = self._resolve_host_with_throttle(key, host, port, futures, traces)
        loop = asyncio.get_running_loop()
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send immediately
            resolved_host_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            resolved_host_task = loop.create_task(coro)

        if not resolved_host_task.done():
            # Keep a strong reference so the task is not garbage-collected
            # mid-flight; drop it automatically on completion.
            self._resolve_host_tasks.add(resolved_host_task)
            resolved_host_task.add_done_callback(self._resolve_host_tasks.discard)

        try:
            return await asyncio.shield(resolved_host_task)
        except asyncio.CancelledError:
            # Only this waiter was cancelled; the lookup keeps running for
            # the other waiters, so swallow its eventual outcome silently.
            def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
                with suppress(Exception, asyncio.CancelledError):
                    fut.result()

            resolved_host_task.add_done_callback(drop_exception)
            raise
    async def _resolve_host_with_throttle(
        self,
        key: Tuple[str, int],
        host: str,
        port: int,
        futures: Set["asyncio.Future[None]"],
        traces: Optional[Sequence["Trace"]],
    ) -> List[ResolveResult]:
        """Resolve host and set result for all waiters.

        This method must be run in a task and shielded from cancellation
        to avoid cancelling the underlying lookup.
        """
        if traces:
            for trace in traces:
                await trace.send_dns_cache_miss(host)
        try:
            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_start(host)

            addrs = await self._resolver.resolve(host, port, family=self._family)
            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_end(host)

            # Publish the result to the cache, then wake every waiter.
            self._cached_hosts.add(key, addrs)
            for fut in futures:
                set_result(fut, None)
        except BaseException as e:
            # any DNS exception is set for the waiters to raise the same exception.
            # This coro is always run in task that is shielded from cancellation so
            # we should never be propagating cancellation here.
            for fut in futures:
                set_exception(fut, e)
            raise
        finally:
            # The lookup is no longer in flight; drop the throttle entry.
            self._throttle_dns_futures.pop(key)

        return self._cached_hosts.next_addrs(key)
async def _create_connection(
    self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
    """Create connection.

    Has same keyword arguments as BaseEventLoop.create_connection.
    """
    # Dispatch on whether the request is routed through a proxy; both
    # branches return (transport, protocol) and only the protocol is needed.
    connect = (
        self._create_proxy_connection if req.proxy else self._create_direct_connection
    )
    _, proto = await connect(req, traces, timeout)
    return proto
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
    """Return the SSL context to use for *req*, or ``None`` for plain HTTP.

    Selection order:

    1. no TLS requested -> ``None``
    2. an explicit ``ssl.SSLContext`` on the request, then on the
       connector, wins outright
    3. a non-``True`` setting (``False`` or a ``Fingerprint``) on the
       request, then on the connector, selects the shared unverified
       context
    4. otherwise the shared verified default context is used
    """
    if not req.is_ssl():
        return None
    if ssl is None:  # pragma: no cover
        raise RuntimeError("SSL is not supported.")
    # Request-level setting takes precedence over the connector-level one.
    for configured in (req.ssl, self._ssl):
        if isinstance(configured, ssl.SSLContext):
            return configured
        if configured is not True:
            # not verified or fingerprinted
            return _SSL_CONTEXT_UNVERIFIED
    return _SSL_CONTEXT_VERIFIED
def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
    """Return the pinned certificate fingerprint, if one is configured.

    The request-level ``ssl`` setting is consulted before the
    connector-level one; ``None`` means no fingerprint pinning.
    """
    for candidate in (req.ssl, self._ssl):
        if isinstance(candidate, Fingerprint):
            return candidate
    return None
async def _wrap_create_connection(
    self,
    *args: Any,
    addr_infos: List[aiohappyeyeballs.AddrInfoType],
    req: ClientRequest,
    timeout: "ClientTimeout",
    client_error: Type[Exception] = ClientConnectorError,
    **kwargs: Any,
) -> Tuple[asyncio.Transport, ResponseHandler]:
    """Connect to one of ``addr_infos`` and wrap the socket in a protocol.

    Uses aiohappyeyeballs to race/iterate the candidate addresses, then
    hands the connected socket to ``loop.create_connection``.  Low-level
    certificate, TLS and OS errors are re-raised as the corresponding
    aiohttp client exceptions keyed on ``req.connection_key``.
    """
    try:
        async with ceil_timeout(
            timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
        ):
            sock = await aiohappyeyeballs.start_connection(
                addr_infos=addr_infos,
                local_addr_infos=self._local_addr_infos,
                happy_eyeballs_delay=self._happy_eyeballs_delay,
                interleave=self._interleave,
                loop=self._loop,
            )
            return await self._loop.create_connection(*args, **kwargs, sock=sock)
    except cert_errors as exc:
        raise ClientConnectorCertificateError(req.connection_key, exc) from exc
    except ssl_errors as exc:
        raise ClientConnectorSSLError(req.connection_key, exc) from exc
    except OSError as exc:
        # A timeout (TimeoutError is an OSError subclass on 3.11+) carries
        # no errno; let it propagate instead of masking it as client_error.
        if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
            raise
        raise client_error(req.connection_key, exc) from exc
async def _wrap_existing_connection(
    self,
    *args: Any,
    req: ClientRequest,
    timeout: "ClientTimeout",
    client_error: Type[Exception] = ClientConnectorError,
    **kwargs: Any,
) -> Tuple[asyncio.Transport, ResponseHandler]:
    """Attach a protocol to an already-connected socket.

    The socket (and any TLS options) arrive via ``kwargs`` and are passed
    straight to ``loop.create_connection``; low-level errors are mapped to
    aiohttp client exceptions keyed on ``req.connection_key``.
    """
    try:
        async with ceil_timeout(
            timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
        ):
            return await self._loop.create_connection(*args, **kwargs)
    except cert_errors as certificate_error:
        raise ClientConnectorCertificateError(
            req.connection_key, certificate_error
        ) from certificate_error
    except ssl_errors as tls_error:
        raise ClientConnectorSSLError(req.connection_key, tls_error) from tls_error
    except OSError as os_error:
        # A bare timeout has no errno attached; re-raise it unchanged.
        if isinstance(os_error, asyncio.TimeoutError) and os_error.errno is None:
            raise
        raise client_error(req.connection_key, os_error) from os_error
| def _fail_on_no_start_tls(self, req: "ClientRequest") -> None: | |
| """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``. | |
| It is necessary for TLS-in-TLS so that it is possible to | |
| send HTTPS queries through HTTPS proxies. | |
| This doesn't affect regular HTTP requests, though. | |
| """ | |
| if not req.is_ssl(): | |
| return | |
| proxy_url = req.proxy | |
| assert proxy_url is not None | |
| if proxy_url.scheme != "https": | |
| return | |
| self._check_loop_for_start_tls() | |
def _check_loop_for_start_tls(self) -> None:
    """Raise ``RuntimeError`` if the event loop has no ``start_tls()``.

    ``loop.start_tls()`` is needed to upgrade an established proxy
    connection to TLS (TLS-in-TLS); per the message below this requires
    Python 3.11+.
    """
    try:
        # Attribute probe only -- we just check the API exists.
        self._loop.start_tls
    except AttributeError as attr_exc:
        raise RuntimeError(
            "An HTTPS request is being sent through an HTTPS proxy. "
            "This needs support for TLS in TLS but it is not implemented "
            "in your runtime for the stdlib asyncio.\n\n"
            "Please upgrade to Python 3.11 or higher. For more details, "
            "please see:\n"
            "* https://bugs.python.org/issue37179\n"
            "* https://github.com/python/cpython/pull/28073\n"
            "* https://docs.aiohttp.org/en/stable/"
            "client_advanced.html#proxy-support\n"
            "* https://github.com/aio-libs/aiohttp/discussions/6044\n",
        ) from attr_exc
| def _loop_supports_start_tls(self) -> bool: | |
| try: | |
| self._check_loop_for_start_tls() | |
| except RuntimeError: | |
| return False | |
| else: | |
| return True | |
def _warn_about_tls_in_tls(
    self,
    underlying_transport: asyncio.Transport,
    req: ClientRequest,
) -> None:
    """Issue a warning if the requested URL has HTTPS scheme."""
    if req.request_info.url.scheme != "https":
        return
    if getattr(underlying_transport, "_start_tls_compatible", False):
        # The transport declares it can be wrapped again; nothing to warn about.
        return
    warnings.warn(
        "An HTTPS request is being sent through an HTTPS proxy. "
        "This support for TLS in TLS is known to be disabled "
        "in the stdlib asyncio (Python <3.11). This is why you'll probably see "
        "an error in the log below.\n\n"
        "It is possible to enable it via monkeypatching. "
        "For more details, see:\n"
        "* https://bugs.python.org/issue37179\n"
        "* https://github.com/python/cpython/pull/28073\n\n"
        "You can temporarily patch this as follows:\n"
        "* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n"
        "* https://github.com/aio-libs/aiohttp/discussions/6044\n",
        RuntimeWarning,
        source=self,
        # Skip the frames of this class's methods so the warning points at
        # the caller's code.
        stacklevel=3,
    )
async def _start_tls_connection(
    self,
    underlying_transport: asyncio.Transport,
    req: ClientRequest,
    timeout: "ClientTimeout",
    client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
    """Wrap the raw TCP transport with TLS.

    Used for TLS-in-TLS through an HTTPS proxy: a fresh protocol is
    created for the TLS layer, the existing transport is upgraded via
    ``loop.start_tls()``, and (when configured) the server certificate
    is checked against the pinned fingerprint.
    """
    tls_proto = self._factory()  # Create a brand new proto for TLS
    sslcontext = self._get_ssl_context(req)
    if TYPE_CHECKING:
        # _start_tls_connection is unreachable in the current code path
        # if sslcontext is None.
        assert sslcontext is not None
    try:
        async with ceil_timeout(
            timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
        ):
            try:
                tls_transport = await self._loop.start_tls(
                    underlying_transport,
                    tls_proto,
                    sslcontext,
                    server_hostname=req.server_hostname or req.host,
                    ssl_handshake_timeout=timeout.total,
                )
            except BaseException:
                # We need to close the underlying transport since
                # `start_tls()` probably failed before it had a
                # chance to do this:
                underlying_transport.close()
                raise
            if isinstance(tls_transport, asyncio.Transport):
                fingerprint = self._get_fingerprint(req)
                if fingerprint:
                    try:
                        fingerprint.check(tls_transport)
                    except ServerFingerprintMismatch:
                        tls_transport.close()
                        # Queue the transport for the cleanup-closed machinery
                        # so its underlying socket is eventually reaped.
                        if not self._cleanup_closed_disabled:
                            self._cleanup_closed_transports.append(tls_transport)
                        raise
    except cert_errors as exc:
        raise ClientConnectorCertificateError(req.connection_key, exc) from exc
    except ssl_errors as exc:
        raise ClientConnectorSSLError(req.connection_key, exc) from exc
    except OSError as exc:
        # Timeouts carry no errno; let them propagate unchanged.
        if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
            raise
        raise client_error(req.connection_key, exc) from exc
    except TypeError as type_err:
        # Example cause looks like this:
        # TypeError: transport <asyncio.sslproto._SSLProtocolTransport
        # object at 0x7f760615e460> is not supported by start_tls()
        raise ClientConnectionError(
            "Cannot initialize a TLS-in-TLS connection to host "
            f"{req.host!s}:{req.port:d} through an underlying connection "
            f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} "
            f"[{type_err!s}]"
        ) from type_err
    else:
        if tls_transport is None:
            msg = "Failed to start TLS (possibly caused by closing transport)"
            raise client_error(req.connection_key, OSError(msg))
        tls_proto.connection_made(
            tls_transport
        )  # Kick the state machine of the new TLS protocol
    return tls_transport, tls_proto
def _convert_hosts_to_addr_infos(
    self, hosts: List[ResolveResult]
) -> List[aiohappyeyeballs.AddrInfoType]:
    """Converts the list of hosts to a list of addr_infos.

    The list of hosts is the result of a DNS lookup. The list of
    addr_infos is the result of a call to `socket.getaddrinfo()`.
    """
    infos: List[aiohappyeyeballs.AddrInfoType] = []
    for entry in hosts:
        address = entry["host"]
        # A colon can only appear in an IPv6 literal.
        family = socket.AF_INET6 if ":" in address else socket.AF_INET
        # Honour a connector-level address-family restriction, if any.
        if self._family and family != self._family:
            continue
        if family == socket.AF_INET6:
            # IPv6 sockaddrs carry flowinfo and scope_id fields.
            sockaddr: Any = (address, entry["port"], 0, 0)
        else:
            sockaddr = (address, entry["port"])
        infos.append((family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", sockaddr))
    return infos
async def _create_direct_connection(
    self,
    req: ClientRequest,
    traces: List["Trace"],
    timeout: "ClientTimeout",
    *,
    client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.Transport, ResponseHandler]:
    """Resolve the target host and connect to it without a proxy.

    Each resolved address is attempted in turn; addresses that fail to
    connect or fail the pinned-fingerprint check are dropped and the
    next is tried.  If every address fails, the last error is raised.
    """
    sslcontext = self._get_ssl_context(req)
    fingerprint = self._get_fingerprint(req)
    host = req.url.raw_host
    assert host is not None
    # Replace multiple trailing dots with a single one.
    # A trailing dot is only present for fully-qualified domain names.
    # See https://github.com/aio-libs/aiohttp/pull/7364.
    if host.endswith(".."):
        host = host.rstrip(".") + "."
    port = req.port
    assert port is not None
    try:
        # Cancelling this lookup should not cancel the underlying lookup
        # or else the cancel event will get broadcast to all the waiters
        # across all connections.
        hosts = await self._resolve_host(host, port, traces=traces)
    except OSError as exc:
        if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
            raise
        # in case of proxy it is not ClientProxyConnectionError
        # it is problem of resolving proxy ip itself
        raise ClientConnectorDNSError(req.connection_key, exc) from exc
    last_exc: Optional[Exception] = None
    addr_infos = self._convert_hosts_to_addr_infos(hosts)
    while addr_infos:
        # Strip trailing dots, certificates contain FQDN without dots.
        # See https://github.com/aio-libs/aiohttp/issues/3636
        server_hostname = (
            (req.server_hostname or host).rstrip(".") if sslcontext else None
        )
        try:
            transp, proto = await self._wrap_create_connection(
                self._factory,
                timeout=timeout,
                ssl=sslcontext,
                addr_infos=addr_infos,
                server_hostname=server_hostname,
                req=req,
                client_error=client_error,
            )
        except (ClientConnectorError, asyncio.TimeoutError) as exc:
            last_exc = exc
            # Drop the attempted address(es) per the interleave setting and
            # retry with the remainder.
            aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave)
            continue
        if req.is_ssl() and fingerprint:
            try:
                fingerprint.check(transp)
            except ServerFingerprintMismatch as exc:
                transp.close()
                if not self._cleanup_closed_disabled:
                    self._cleanup_closed_transports.append(transp)
                last_exc = exc
                # Remove the bad peer from the list of addr_infos
                sock: socket.socket = transp.get_extra_info("socket")
                bad_peer = sock.getpeername()
                aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer)
                continue
        return transp, proto
    else:
        # while-else: only reached once addr_infos is exhausted without a
        # successful return, so some attempt must have recorded an error.
        assert last_exc is not None
        raise last_exc
async def _create_proxy_connection(
    self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
    """Connect to ``req.proxy`` and, for HTTPS targets, tunnel through it.

    Plain HTTP requests reuse the proxy connection directly; HTTPS
    requests issue a CONNECT and then upgrade the tunnel to TLS, either
    via ``loop.start_tls()`` or (on runtimes without it) by duplicating
    the raw socket and building a new TLS connection on top of it.
    """
    self._fail_on_no_start_tls(req)
    runtime_has_start_tls = self._loop_supports_start_tls()
    headers: Dict[str, str] = {}
    if req.proxy_headers is not None:
        headers = req.proxy_headers  # type: ignore[assignment]
    headers[hdrs.HOST] = req.headers[hdrs.HOST]
    url = req.proxy
    assert url is not None
    proxy_req = ClientRequest(
        hdrs.METH_GET,
        url,
        headers=headers,
        auth=req.proxy_auth,
        loop=self._loop,
        ssl=req.ssl,
    )
    # create connection to proxy server
    transport, proto = await self._create_direct_connection(
        proxy_req, [], timeout, client_error=ClientProxyConnectionError
    )
    # Move the Authorization header computed for the proxy request to the
    # appropriate Proxy-Authorization header.
    auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
    if auth is not None:
        if not req.is_ssl():
            # Plain HTTP: the real request itself goes to the proxy.
            req.headers[hdrs.PROXY_AUTHORIZATION] = auth
        else:
            # HTTPS: only the CONNECT request is seen by the proxy.
            proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
    if req.is_ssl():
        if runtime_has_start_tls:
            self._warn_about_tls_in_tls(transport, req)
        # For HTTPS requests over HTTP proxy
        # we must notify proxy to tunnel connection
        # so we send CONNECT command:
        #   CONNECT www.python.org:443 HTTP/1.1
        #   Host: www.python.org
        #
        # next we must do TLS handshake and so on
        # to do this we must wrap raw socket into secure one
        # asyncio handles this perfectly
        proxy_req.method = hdrs.METH_CONNECT
        proxy_req.url = req.url
        key = req.connection_key._replace(
            proxy=None, proxy_auth=None, proxy_headers_hash=None
        )
        conn = Connection(self, key, proto, self._loop)
        proxy_resp = await proxy_req.send(conn)
        try:
            protocol = conn._protocol
            assert protocol is not None
            # read_until_eof=True will ensure the connection isn't closed
            # once the response is received and processed allowing
            # START_TLS to work on the connection below.
            protocol.set_response_params(
                read_until_eof=runtime_has_start_tls,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )
            resp = await proxy_resp.start(conn)
        except BaseException:
            proxy_resp.close()
            conn.close()
            raise
        else:
            # Detach the protocol from the Connection wrapper: from here on
            # the transport/protocol pair is managed directly.
            conn._protocol = None
            try:
                if resp.status != 200:
                    message = resp.reason
                    if message is None:
                        message = HTTPStatus(resp.status).phrase
                    raise ClientHttpProxyError(
                        proxy_resp.request_info,
                        resp.history,
                        status=resp.status,
                        message=message,
                        headers=resp.headers,
                    )
                if not runtime_has_start_tls:
                    rawsock = transport.get_extra_info("socket", default=None)
                    if rawsock is None:
                        raise RuntimeError(
                            "Transport does not expose socket instance"
                        )
                    # Duplicate the socket, so now we can close proxy transport
                    rawsock = rawsock.dup()
            except BaseException:
                # It shouldn't be closed in `finally` because it's fed to
                # `loop.start_tls()` and the docs say not to touch it after
                # passing there.
                transport.close()
                raise
            finally:
                if not runtime_has_start_tls:
                    transport.close()
            if not runtime_has_start_tls:
                # HTTP proxy with support for upgrade to HTTPS
                sslcontext = self._get_ssl_context(req)
                return await self._wrap_existing_connection(
                    self._factory,
                    timeout=timeout,
                    ssl=sslcontext,
                    sock=rawsock,
                    server_hostname=req.host,
                    req=req,
                )
            return await self._start_tls_connection(
                # Access the old transport for the last time before it's
                # closed and forgotten forever:
                transport,
                req=req,
                timeout=timeout,
            )
        finally:
            proxy_resp.close()
    # Plain HTTP through the proxy: reuse the proxy connection as-is.
    return transport, proto
class UnixConnector(BaseConnector):
    """Unix socket connector.

    path - Unix socket path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
            limit=limit,
            limit_per_host=limit_per_host,
            loop=loop,
        )
        self._path = path

    # Fix: `path` must be a property -- `_create_connection` below passes
    # `self.path` to UnixClientConnectorError expecting the path string;
    # without @property it would pass the bound method instead.
    @property
    def path(self) -> str:
        """Path to unix socket."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open a connection to the configured unix socket.

        Raises UnixClientConnectorError on OS-level failures; bare
        timeouts (no errno) propagate unchanged.
        """
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                _, proto = await self._loop.create_unix_connection(
                    self._factory, self._path
                )
        except OSError as exc:
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            raise UnixClientConnectorError(self.path, req.connection_key, exc) from exc
        return proto
class NamedPipeConnector(BaseConnector):
    """Named pipe connector.

    Only supported by the proactor event loop.
    See also: https://docs.python.org/3/library/asyncio-eventloop.html

    path - Windows named pipe path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
            limit=limit,
            limit_per_host=limit_per_host,
            loop=loop,
        )
        # Named pipes are only available on the Windows proactor loop.
        if not isinstance(
            self._loop, asyncio.ProactorEventLoop  # type: ignore[attr-defined]
        ):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        self._path = path

    # Fix: expose `path` as a read-only property (accessor, not a method),
    # consistent with UnixConnector.path.
    @property
    def path(self) -> str:
        """Path to the named pipe."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open a connection to the configured named pipe."""
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                _, proto = await self._loop.create_pipe_connection(  # type: ignore[attr-defined]
                    self._factory, self._path
                )
                # the drain is required so that the connection_made is called
                # and transport is set otherwise it is not set before the
                # `assert conn.transport is not None`
                # in client.py's _request method
                await asyncio.sleep(0)
                # other option is to manually set transport like
                # `proto.transport = trans`
        except OSError as exc:
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            raise ClientConnectorError(req.connection_key, exc) from exc
        return cast(ResponseHandler, proto)