| | import asyncio |
| | import socket |
| | import weakref |
| | from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union |
| |
|
| | from .abc import AbstractResolver, ResolveResult |
| |
|
# Public names exported by ``from <module> import *``.
__all__ = (
    "ThreadedResolver",
    "AsyncResolver",
    "DefaultResolver",
)
| |
|
| |
|
# aiodns is optional. ``aiodns_default`` is True only when it is importable and
# its DNSResolver exposes getaddrinfo(); it is used to pick DefaultResolver.
try:
    import aiodns
except ImportError:
    aiodns = None
    aiodns_default = False
else:
    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
| |
|
| |
|
# Flags attached to resolved results: both host and port are already numeric.
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# getnameinfo() flags requesting numeric host and service strings.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
# getaddrinfo() flag, clipped to AI_MASK on platforms that define that mask.
_AI_ADDRCONFIG = (
    socket.AI_ADDRCONFIG & socket.AI_MASK
    if hasattr(socket, "AI_MASK")
    else socket.AI_ADDRCONFIG
)
| |
|
| |
|
class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        # Fall back to the currently running loop when none is supplied.
        self._loop = loop or asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* into numeric TCP addresses via the loop's getaddrinfo().

        Args:
            host: Host name to resolve.
            port: Port forwarded to getaddrinfo().
            family: Address family to restrict the lookup to.

        Returns:
            One ``ResolveResult`` per returned address. ``hostname`` is always
            the queried *host*; ``host``/``port`` are the resolved numeric values.
        """
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: List[ResolveResult] = []
        # Distinct loop names so the ``family``/``port`` parameters are not
        # shadowed by per-result values.
        for info_family, _, proto, _, address in infos:
            if info_family == socket.AF_INET6:
                if len(address) < 4:
                    # Not a (host, port, flowinfo, scope_id) tuple; presumably
                    # IPv6 is unsupported by this Python build or host. The
                    # guard must cover index 3, which is read just below.
                    continue
                if address[3]:
                    # Non-zero scope_id (e.g. link-local address): let
                    # getnameinfo() build the host string so the scope is
                    # reflected in it. Skipped otherwise since it costs an
                    # extra call.
                    resolved_host, service = await self._loop.getnameinfo(
                        address, _NAME_SOCKET_FLAGS
                    )
                    resolved_port = int(service)
                else:
                    resolved_host, resolved_port = address[:2]
            else:
                assert info_family == socket.AF_INET
                resolved_host, resolved_port = address
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=resolved_port,
                    family=info_family,
                    proto=proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return hosts

    async def close(self) -> None:
        """Nothing to release: this resolver holds no resources of its own."""
        pass
| |
|
| |
|
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups.

    Instances constructed without extra arguments share one
    ``aiodns.DNSResolver`` per event loop through ``_DNSResolverManager``.
    Instances constructed with extra arguments own a dedicated resolver.
    """

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the resolver.

        Args:
            loop: Event loop to use; defaults to the running loop.
            *args: Forwarded to ``aiodns.DNSResolver``.
            **kwargs: Forwarded to ``aiodns.DNSResolver``. When any positional
                or keyword argument is given, a dedicated resolver is created;
                otherwise the per-loop shared resolver is used.

        Raises:
            RuntimeError: If the aiodns library is not installed.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = loop or asyncio.get_running_loop()
        self._manager: Optional[_DNSResolverManager] = None
        if args or kwargs:
            # Custom arguments: a dedicated resolver, cancelled by close().
            self._resolver = aiodns.DNSResolver(*args, **kwargs)
        else:
            # Default arguments: borrow the resolver shared on this loop;
            # close() hands it back to the manager.
            self._manager = _DNSResolverManager()
            self._resolver = self._manager.get_resolver(self, self._loop)

        # resolve() calls DNSResolver.getaddrinfo(); when the installed aiodns
        # lacks it, fall back to plain A/AAAA queries. Applies to both the
        # dedicated and the shared resolver.
        if not hasattr(self._resolver, "getaddrinfo"):
            self.resolve = self._resolve_with_query  # type: ignore[method-assign]

    @staticmethod
    def _dns_error_message(exc: BaseException) -> str:
        """Return the message of an aiodns DNSError, or a generic fallback.

        The message is expected at ``args[1]``; an error carrying fewer than
        two args would otherwise raise IndexError while being translated.
        """
        return exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* with aiodns getaddrinfo().

        Raises:
            OSError: If the lookup fails or yields no addresses.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            raise OSError(None, self._dns_error_message(exc)) from exc
        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            family = node.family
            if family == socket.AF_INET6:
                if len(address) > 3 and address[3]:
                    # Non-zero scope_id (e.g. link-local address): let
                    # getnameinfo() build the host string so the scope is
                    # reflected in it.
                    result = await self._resolver.getnameinfo(
                        (address[0].decode("ascii"), *address[1:]),
                        _NAME_SOCKET_FLAGS,
                    )
                    resolved_host = result.node
                else:
                    resolved_host = address[0].decode("ascii")
                    port = address[1]
            else:
                assert family == socket.AF_INET
                resolved_host = address[0].decode("ascii")
                port = address[1]
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=port,
                    family=family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Fallback resolution via an A (or AAAA for AF_INET6) query.

        Raises:
            OSError: If the query fails or returns no records.
        """
        qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            raise OSError(None, self._dns_error_message(exc)) from exc

        hosts = []
        for rr in resp:
            hosts.append(
                {
                    "hostname": host,
                    "host": rr.host,
                    "port": port,
                    "family": family,
                    "proto": 0,
                    "flags": socket.AI_NUMERICHOST,
                }
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Release the underlying aiodns resolver; repeated calls are no-ops."""
        if self._manager is not None:
            # Shared resolver: the manager cancels it once no client remains.
            self._manager.release_resolver(self, self._loop)
            self._manager = None
            self._resolver = None
            return

        if self._resolver is not None:
            # Dedicated resolver owned by this instance.
            self._resolver.cancel()
            self._resolver = None
| | self._resolver = None |
| |
|
| |
|
| | class _DNSResolverManager: |
| | """Manager for aiodns.DNSResolver objects. |
| | |
| | This class manages shared aiodns.DNSResolver instances |
| | with no custom arguments across different event loops. |
| | """ |
| |
|
| | _instance: Optional["_DNSResolverManager"] = None |
| |
|
| | def __new__(cls) -> "_DNSResolverManager": |
| | if cls._instance is None: |
| | cls._instance = super().__new__(cls) |
| | cls._instance._init() |
| | return cls._instance |
| |
|
| | def _init(self) -> None: |
| | |
| | self._loop_data: weakref.WeakKeyDictionary[ |
| | asyncio.AbstractEventLoop, |
| | tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]], |
| | ] = weakref.WeakKeyDictionary() |
| |
|
| | def get_resolver( |
| | self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| | ) -> "aiodns.DNSResolver": |
| | """Get or create the shared aiodns.DNSResolver instance for a specific event loop. |
| | |
| | Args: |
| | client: The AsyncResolver instance requesting the resolver. |
| | This is required to track resolver usage. |
| | loop: The event loop to use for the resolver. |
| | """ |
| | |
| | if loop not in self._loop_data: |
| | resolver = aiodns.DNSResolver(loop=loop) |
| | client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet() |
| | self._loop_data[loop] = (resolver, client_set) |
| | else: |
| | |
| | resolver, client_set = self._loop_data[loop] |
| |
|
| | |
| | client_set.add(client) |
| | return resolver |
| |
|
| | def release_resolver( |
| | self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop |
| | ) -> None: |
| | """Release the resolver for an AsyncResolver client when it's closed. |
| | |
| | Args: |
| | client: The AsyncResolver instance to release. |
| | loop: The event loop the resolver was using. |
| | """ |
| | |
| | current_loop_data = self._loop_data.get(loop) |
| | if current_loop_data is None: |
| | return |
| | resolver, client_set = current_loop_data |
| | client_set.discard(client) |
| | |
| | if not client_set: |
| | if resolver is not None: |
| | resolver.cancel() |
| | del self._loop_data[loop] |
| |
|
| |
|
# Type of whichever resolver class is exported as the default.
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]

# aiodns_default is true only when aiodns is importable and its DNSResolver
# has getaddrinfo(); otherwise the thread-based resolver is the default.
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver
| |
|