| import asyncio |
| import ipaddress |
| from pathlib import Path |
| import socket |
| import os |
| from functools import lru_cache, wraps |
| from typing import Any, Awaitable, Callable, Coroutine, Literal, T, Tuple |
|
|
| import httpx |
|
|
def get_version():
    """Return the package version stored in ``version.txt`` beside this module.

    The file is read as text and surrounding whitespace (including the
    trailing newline) is stripped from the result.
    """
    module_dir = Path(__file__).parent
    return (module_dir / 'version.txt').read_text().strip()
|
|
# Package version string; computed once, when this module is imported, by
# reading version.txt through get_version().
__version__ = get_version()
|
|
def is_public_ip(ip: str) -> bool:
    """Tell whether ``ip`` is a publicly routable IPv4/IPv6 address.

    Returns ``False`` when ``ip`` cannot be parsed as an IP address, or when
    the address is private, loopback, link-local, multicast or reserved.
    Returns ``True`` otherwise.
    """
    try:
        address = ipaddress.ip_address(ip)
    except ValueError:
        # Not an IP literal at all, so it cannot be a public address.
        return False

    non_public_flags = (
        address.is_private,
        address.is_loopback,
        address.is_link_local,
        address.is_multicast,
        address.is_reserved,
    )
    return not any(non_public_flags)
|
|
|
|
def matches_domain_whitelist(hostname: str, domain_whitelist: list[str]) -> bool:
    """Check ``hostname`` against a list of allowed domains (case-insensitive).

    A plain entry such as ``"example.com"`` matches only that exact host.
    An entry starting with ``"*."`` (e.g. ``"*.example.com"``) matches the
    bare domain itself and any subdomain of it. Empty entries are ignored.
    An empty hostname or an empty whitelist never matches.
    """
    if not (hostname and domain_whitelist):
        return False

    host = hostname.lower()

    def _allows(entry: str) -> bool:
        pattern = entry.lower()
        if not pattern.startswith('*.'):
            return host == pattern
        base = pattern[2:]
        # The leading dot keeps "badexample.com" from matching "*.example.com".
        return host == base or host.endswith('.' + base)

    return any(_allows(entry) for entry in domain_whitelist if entry)
|
|
|
|
def lru_cache_async(maxsize: int = 256):
    """Build a decorator that memoizes an ``async`` function with ``lru_cache``.

    The decorated callable is a *synchronous* function that starts the
    coroutine as an ``asyncio.Task`` and returns that task; the task itself is
    what gets cached. Callers passing the same arguments therefore share one
    task (and one result), including while it is still in flight.

    Tasks that raise or are cancelled are not kept: when one finishes that
    way the cache is cleared, so the next call retries instead of replaying a
    stale failure (or a ``CancelledError`` triggered by an unrelated caller).

    Parameters:
    - maxsize (int): Maximum number of cached tasks, forwarded to ``functools.lru_cache``.

    Notes:
    - Arguments must be hashable, as required by ``functools.lru_cache``.
    - The wrapper must be called while an event loop is running, because it
      uses ``asyncio.create_task``.
    - The wrapper keeps the ``cache_info()`` / ``cache_clear()`` helpers that
      ``functools.lru_cache`` attaches.
    """

    def decorator(
        async_func: Callable[..., Coroutine[Any, Any, T]],
    ) -> Callable[..., Awaitable[T]]:
        def _forget_failure(task: "asyncio.Task[T]") -> None:
            # lru_cache cannot evict one key, so drop the whole cache. That
            # only costs recomputation; tasks already handed out keep running.
            # cancelled() is checked first because exception() raises on a
            # cancelled task.
            if task.cancelled() or task.exception() is not None:
                wrapper.cache_clear()

        @lru_cache(maxsize=maxsize)
        @wraps(async_func)
        def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]:
            task = asyncio.create_task(async_func(*args, **kwargs))
            task.add_done_callback(_forget_failure)
            return task

        return wrapper

    return decorator
|
|
|
|
@lru_cache_async()
async def async_resolve_hostname_google(hostname: str) -> list[str]:
    """Resolve ``hostname`` through Google's DNS-over-HTTPS JSON endpoint.

    Queries ``https://dns.google/resolve`` for ``A`` and then ``AAAA`` records
    and returns the ``data`` field of every entry in each response's
    ``Answer`` list, IPv4 answers first. The values are not validated here;
    callers are expected to check them (e.g. with ``is_public_ip``).

    This is a best-effort lookup: any error (network failure, non-JSON body,
    unexpected payload shape) yields an empty list instead of raising.

    Results are memoized per hostname by ``lru_cache_async``.

    Parameters:
    - hostname (str): The host name to resolve.

    Returns:
    - list[str]: The answers' ``data`` values, or ``[]`` on any failure.
    """
    async with httpx.AsyncClient() as client:
        try:
            ips: list[str] = []
            for record_type in ("A", "AAAA"):
                # Pass the hostname via ``params`` so httpx percent-encodes it;
                # interpolating it into the URL would let characters such as
                # "&" or "#" rewrite the query.
                response = await client.get(
                    "https://dns.google/resolve",
                    params={"name": hostname, "type": record_type},
                )
                ips.extend(
                    answer["data"] for answer in response.json().get("Answer", [])
                )
            return ips
        except Exception:
            # Deliberate best-effort: a failed lookup means "no addresses".
            return []
|
|
|
|
async def async_validate_url(hostname: str) -> str:
    """Resolve ``hostname`` and return the first public IP address found.

    Despite the name, the argument is a host name, not a full URL.

    The system resolver is tried first (via the running loop's
    ``getaddrinfo``). If none of its IPv4/IPv6 results pass ``is_public_ip``,
    the addresses returned by ``async_resolve_hostname_google`` are tried.

    Parameters:
    - hostname (str): The host name to resolve and validate.

    Returns:
    - str: The first address that ``is_public_ip`` accepts.

    Raises:
    - ValueError: If the system resolver cannot resolve the host name, or if
      neither lookup yields a public address.
    """
    # get_running_loop() is the preferred way to obtain the loop in a coroutine.
    loop = asyncio.get_running_loop()
    try:
        addrinfo = await loop.getaddrinfo(hostname, None)
    except socket.gaierror as e:
        raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

    for family, _, _, _, sockaddr in addrinfo:
        # For AF_INET / AF_INET6 the first sockaddr element is the address.
        ip_address = sockaddr[0]
        if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
            return ip_address

    # Fallback: the local resolver produced no public address.
    for ip_address in await async_resolve_hostname_google(hostname):
        if is_public_ip(ip_address):
            return ip_address

    raise ValueError(f"Hostname {hostname} failed validation")
|
|
|
|
class AsyncSecureTransport(httpx.AsyncHTTPTransport):
    """HTTP transport that connects to a pre-verified IP address.

    Each request's URL host is replaced by ``verified_ip`` so the connection
    goes to the address that was already validated rather than being resolved
    again. The original host name is still sent in the ``Host`` header and as
    the TLS SNI name.
    """

    def __init__(self, verified_ip: str):
        """Store the IP address every request will be sent to.

        Parameters:
        - verified_ip (str): The address to connect to in place of the URL host.
        """
        self.verified_ip = verified_ip
        super().__init__()

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Send ``request`` to ``verified_ip`` while presenting the original host.

        Parameters:
        - request (httpx.Request): The outgoing request; it is modified in place.

        Returns:
        - httpx.Response: The response produced by the parent transport.
        """
        original_host = request.url.host
        request.url = request.url.copy_with(host=self.verified_ip)
        # NOTE(review): this sets Host to the bare host name, without any
        # explicit port from the URL; confirm that is intended.
        request.headers['Host'] = original_host
        # Merge rather than replace: httpx already stores per-request settings
        # (such as "timeout") in extensions, and overwriting the dict would
        # silently discard them.
        request.extensions = {**request.extensions, "sni_hostname": original_host}
        return await super().handle_async_request(request)
|
|
async def get(
    url: str,
    domain_whitelist: list[str] | None = None,
    _transport: httpx.AsyncBaseTransport | Literal[False] | None = None,
    **kwargs,
) -> httpx.Response:
    """
    This is the main function that should be used to make async HTTP GET requests.
    Unless the host is whitelisted or a transport override is given, the hostname is
    resolved and validated first, and the request is sent through a secure transport
    pinned to the validated public IP address. Redirects are never followed.

    Parameters:
    - url (str): The URL to make a GET request to.
    - domain_whitelist (list[str] | None): A list of domains to whitelist, which will not use a secure transport. Supports wildcard subdomains with "*.domain.com" format (asterisk must be at the beginning).
    - _transport (httpx.AsyncBaseTransport | Literal[False] | None): A custom transport to use for the request. Takes precedence over domain_whitelist. Set to False to use no transport.
    - **kwargs: Additional keyword arguments to pass to the httpx.AsyncClient.get() function. `follow_redirects` is always passed as False by this function, so it must not be supplied here.

    Returns:
    - httpx.Response: The response to the GET request.

    Raises:
    - ValueError: If the URL has no hostname, or if the hostname cannot be resolved to a public IP address.
    """
    # NOTE(review): an earlier version of this docstring said the secure
    # transport is skipped when HTTP_PROXY / HTTPS_PROXY (or lowercase
    # variants) are set. This function does not read those variables;
    # confirm which behavior is intended.
    hostname = httpx.URL(url).host
    if not hostname:
        raise ValueError(f"URL {url} does not have a valid hostname")
    if domain_whitelist is None:
        domain_whitelist = []

    if _transport:
        # An explicit transport wins over every other rule.
        transport = _transport
    elif _transport is False or matches_domain_whitelist(hostname, domain_whitelist):
        # Explicit opt-out or whitelisted host: use httpx's default transport.
        transport = None
    else:
        verified_ip = await async_validate_url(hostname)
        transport = AsyncSecureTransport(verified_ip)

    async with httpx.AsyncClient(transport=transport) as client:
        # Redirects are disabled so a response cannot steer the client to a
        # host that was never validated.
        return await client.get(url, follow_redirects=False, **kwargs)
|
|