| import asyncio |
| import socket |
| import weakref |
| from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union |
|
|
| from .abc import AbstractResolver, ResolveResult |
|
|
# Public names exported by ``from <module> import *``.
__all__ = (
    "ThreadedResolver",
    "AsyncResolver",
    "DefaultResolver",
)
|
|
|
|
# aiodns is optional: without it only ThreadedResolver is usable.
try:
    import aiodns
except ImportError:
    aiodns = None

# Prefer aiodns as the default only when it is installed AND its DNSResolver
# provides getaddrinfo(), which AsyncResolver.resolve() calls.
aiodns_default = aiodns is not None and hasattr(aiodns.DNSResolver, "getaddrinfo")
|
|
|
|
# Flags attached to every ResolveResult: host and port are already numeric.
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# Flags for getnameinfo(): return numeric host and service strings.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV

# On platforms that define AI_MASK, keep only the AI_ADDRCONFIG bits it allows.
if hasattr(socket, "AI_MASK"):
    _AI_ADDRCONFIG = socket.AI_ADDRCONFIG & socket.AI_MASK
else:
    _AI_ADDRCONFIG = socket.AI_ADDRCONFIG
|
|
|
|
class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        # Without an explicit loop, bind to the currently running one
        # (asyncio.get_running_loop() raises RuntimeError if none is running).
        self._loop = loop or asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* into numeric stream-socket addresses.

        Args:
            host: Hostname (or address literal) to look up.
            port: Port passed through to getaddrinfo().
            family: Address family requested from getaddrinfo().

        Returns:
            One ResolveResult per usable getaddrinfo() entry, with ``host`` and
            ``port`` in numeric form and ``flags`` set to the numeric-host/
            numeric-service flags. IPv6 entries whose sockaddr is shorter than
            three elements are skipped, so the list may be empty.
        """
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: List[ResolveResult] = []
        # Use distinct names so the caller's ``family``/``port`` arguments
        # are not shadowed by per-entry values.
        for info_family, _, proto, _, address in infos:
            if info_family == socket.AF_INET6:
                if len(address) < 3:
                    # Presumably a Python build or host without IPv6 support,
                    # where the sockaddr lacks flowinfo/scope_id.
                    continue
                # address[3] is the scope id; check the length first so a
                # 3-element sockaddr cannot raise IndexError.
                if len(address) > 3 and address[3]:
                    # Scoped (e.g. link-local) addresses are rare; run
                    # getnameinfo() only for them, presumably so the scope is
                    # preserved in the returned host string.
                    resolved_host, service = await self._loop.getnameinfo(
                        address, _NAME_SOCKET_FLAGS
                    )
                    resolved_port = int(service)
                else:
                    resolved_host, resolved_port = address[:2]
            else:
                assert info_family == socket.AF_INET
                resolved_host, resolved_port = address
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=resolved_port,
                    family=info_family,
                    proto=proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return hosts

    async def close(self) -> None:
        """Release resources; this resolver holds nothing beyond the loop reference."""
        pass
|
|
|
|
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups"""

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the resolver.

        Args:
            loop: Event loop to use; defaults to the running loop.
            *args, **kwargs: Forwarded to ``aiodns.DNSResolver``. If any are
                given, a dedicated resolver is created for this instance;
                otherwise one resolver per event loop is shared through
                ``_DNSResolverManager``.

        Raises:
            RuntimeError: if the aiodns library is not installed.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = loop or asyncio.get_running_loop()
        self._manager: Optional[_DNSResolverManager] = None

        if args or kwargs:
            # Custom options: this instance owns a private resolver.
            self._resolver = aiodns.DNSResolver(*args, **kwargs)
        else:
            # Default options: reuse the resolver shared for this loop.
            self._manager = _DNSResolverManager()
            self._resolver = self._manager.get_resolver(self, self._loop)

        # resolve() is built on getaddrinfo(); when the installed aiodns does
        # not provide it, fall back to plain A/AAAA queries. Applies to both
        # dedicated and shared resolvers.
        if not hasattr(self._resolver, "getaddrinfo"):
            self.resolve = self._resolve_with_query

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* via aiodns getaddrinfo().

        Args:
            host: Hostname (or address literal) to look up.
            port: Port passed through to getaddrinfo().
            family: Address family requested from getaddrinfo().

        Returns:
            One ResolveResult per returned node, with numeric ``host``/``port``.

        Raises:
            OSError: if the lookup fails or yields no addresses.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            # DNSError args are presumably (code, message); only read the
            # message when a second argument actually exists.
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            node_family = node.family
            if node_family == socket.AF_INET6:
                if len(address) > 3 and address[3]:
                    # Scoped (e.g. link-local) IPv6 addresses are rare; use
                    # getnameinfo() only for them, presumably so the scope is
                    # preserved in the host string.
                    result = await self._resolver.getnameinfo(
                        (address[0].decode("ascii"), *address[1:]),
                        _NAME_SOCKET_FLAGS,
                    )
                    resolved_host = result.node
                else:
                    resolved_host = address[0].decode("ascii")
            else:
                assert node_family == socket.AF_INET
                resolved_host = address[0].decode("ascii")
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    # Take the port from this node's own sockaddr in every
                    # branch, including scoped IPv6.
                    port=address[1],
                    family=node_family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Fallback resolve() using an A/AAAA query when getaddrinfo() is missing.

        Raises:
            OSError: if the query fails or returns no records.
        """
        qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            # Same guarded message extraction as resolve().
            msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
            raise OSError(None, msg) from exc

        hosts = [
            {
                "hostname": host,
                "host": rr.host,
                "port": port,
                "family": family,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            }
            for rr in resp
        ]

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Release the underlying aiodns resolver.

        A shared resolver is handed back to the manager, which cancels it
        once no client remains for the loop. A dedicated resolver is
        cancelled directly. Later calls are no-ops.
        """
        if self._manager:
            self._manager.release_resolver(self, self._loop)
            self._manager = None
            self._resolver = None
            return

        if self._resolver is not None:
            self._resolver.cancel()
            self._resolver = None
|
|
|
|
class _DNSResolverManager:
    """Singleton registry of shared aiodns.DNSResolver objects.

    One resolver is kept per event loop and handed to every AsyncResolver
    created without custom arguments on that loop. Loops are held weakly, and
    so are the clients using each resolver.
    """

    _instance: Optional["_DNSResolverManager"] = None

    def __new__(cls) -> "_DNSResolverManager":
        existing = cls._instance
        if existing is not None:
            return existing
        created = super().__new__(cls)
        cls._instance = created
        created._init()
        return created

    def _init(self) -> None:
        # loop -> (shared resolver, weak set of AsyncResolver clients using it)
        self._loop_data: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop,
            tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
        ] = weakref.WeakKeyDictionary()

    def get_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> "aiodns.DNSResolver":
        """Return the loop's shared resolver, creating it on first use.

        Args:
            client: The AsyncResolver requesting the resolver; it is recorded
                as a user so the resolver stays alive while it is needed.
            loop: Event loop the resolver is bound to.
        """
        entry = self._loop_data.get(loop)
        if entry is None:
            fresh_clients: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet()
            entry = (aiodns.DNSResolver(loop=loop), fresh_clients)
            self._loop_data[loop] = entry

        shared_resolver, users = entry
        users.add(client)
        return shared_resolver

    def release_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> None:
        """Drop *client* from the loop's users; tear down when none remain.

        Args:
            client: The AsyncResolver being closed.
            loop: Event loop whose shared resolver it was using.
        """
        entry = self._loop_data.get(loop)
        if entry is None:
            return

        shared_resolver, users = entry
        users.discard(client)
        if users:
            return

        # Last user gone: cancel the resolver and forget the loop.
        if shared_resolver is not None:
            shared_resolver.cancel()
        del self._loop_data[loop]
|
|
|
|
# Resolver class exported as DefaultResolver: the aiodns-backed AsyncResolver
# when aiodns_default is true, otherwise ThreadedResolver.
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver
|
|