# File size: 5,272 Bytes (9d4c5ad)
import asyncio
import ipaddress
from pathlib import Path
import socket
import os
from functools import lru_cache, wraps
from typing import Any, Awaitable, Callable, Coroutine, Literal, T, Tuple
import httpx
def get_version():
    """Return the version string stored in ``version.txt``.

    The file is looked up in the directory that contains this module, and
    its contents are returned with leading/trailing whitespace removed.
    """
    version_path = Path(__file__).with_name('version.txt')
    return version_path.read_text().strip()
# Module version, read once at import time via get_version().
__version__ = get_version()
def is_public_ip(ip: str) -> bool:
    """Tell whether ``ip`` is a syntactically valid, publicly routable address.

    Parameters:
    - ip (str): An IPv4 or IPv6 address in textual form.

    Returns:
    - bool: False when ``ip`` cannot be parsed, or when the parsed address
      is private, loopback, link-local, multicast or reserved; True otherwise.
    """
    try:
        parsed = ipaddress.ip_address(ip)
    except ValueError:
        # Not an IP address at all: treat it as non-public.
        return False

    non_public_flags = (
        "is_private",
        "is_loopback",
        "is_link_local",
        "is_multicast",
        "is_reserved",
    )
    return not any(getattr(parsed, flag) for flag in non_public_flags)
def matches_domain_whitelist(hostname: str, domain_whitelist: list[str]) -> bool:
    """Check whether ``hostname`` is covered by any whitelist entry.

    Matching is case-insensitive. A plain entry must equal the hostname
    exactly. An entry of the form ``*.example.com`` matches ``example.com``
    itself and any subdomain of it. Empty entries are ignored.

    Parameters:
    - hostname (str): The host to test.
    - domain_whitelist (list[str]): Allowed domains / wildcard patterns.

    Returns:
    - bool: True if some entry matches; False otherwise, including when
      either argument is empty.
    """
    if not (hostname and domain_whitelist):
        return False

    host = hostname.lower()

    def _entry_matches(entry: str) -> bool:
        pattern = entry.lower()
        if not pattern.startswith('*.'):
            return host == pattern
        # Wildcard: accept the bare domain as well as any subdomain of it.
        base = pattern[2:]
        return host == base or host.endswith('.' + base)

    return any(_entry_matches(entry) for entry in domain_whitelist if entry)
def lru_cache_async(maxsize: int = 256):
    """Build a decorator that memoizes calls to a coroutine function.

    Each distinct set of arguments schedules the coroutine once as an
    ``asyncio`` task; the task object itself is what ``functools.lru_cache``
    stores, so later calls with the same arguments receive (and can await)
    that same task.

    Parameters:
    - maxsize (int): Maximum number of cached tasks, forwarded to ``lru_cache``.

    Notes:
    - Arguments must be hashable, as required by ``lru_cache``.
    - The decorated function must be called while an event loop is running,
      because ``asyncio.create_task`` is invoked on every cache miss.
    - NOTE(review): since the task is cached rather than its result, a task
      that ended with an exception or was cancelled stays in the cache too.
    """

    def decorator(
        async_func: Callable[..., Coroutine[Any, Any, T]],
    ) -> Callable[..., Awaitable[T]]:
        def schedule(*args: Any, **kwargs: Any) -> Awaitable[T]:
            coroutine = async_func(*args, **kwargs)
            return asyncio.create_task(coroutine)

        # Copy the coroutine function's metadata first, then add the cache.
        named = wraps(async_func)(schedule)
        return lru_cache(maxsize=maxsize)(named)

    return decorator
@lru_cache_async()
async def async_resolve_hostname_google(hostname: str) -> list[str]:
    """Resolve ``hostname`` through the ``https://dns.google/resolve`` endpoint.

    Two lookups are made, one for ``A`` and one for ``AAAA`` records, and the
    ``data`` field of every entry under ``Answer`` in each JSON reply is
    collected.

    Parameters:
    - hostname (str): The host name to look up.

    Returns:
    - list[str]: The collected ``data`` values (A answers first, then AAAA).
      An empty list is returned if anything goes wrong (request failure,
      invalid JSON, unexpected payload shape); this is a best-effort lookup.
    """
    resolver_url = "https://dns.google/resolve"
    async with httpx.AsyncClient() as client:
        try:
            payloads = []
            for record_type in ("A", "AAAA"):
                # Pass the hostname via ``params`` so it is percent-encoded;
                # interpolating it into the URL would let characters such as
                # "&" or "#" alter the query that is actually sent.
                response = await client.get(
                    resolver_url,
                    params={"name": hostname, "type": record_type},
                )
                payloads.append(response.json())
            return [
                answer["data"]
                for payload in payloads
                for answer in payload.get("Answer", [])
            ]
        except Exception:
            # Deliberate best-effort: any failure means "no addresses found".
            return []
async def async_validate_url(hostname: str) -> str:
    """Resolve ``hostname`` and return the first public IP address found.

    The event loop's resolver is tried first. If none of its results is a
    public IPv4/IPv6 address, the addresses returned by
    ``async_resolve_hostname_google`` are checked as a fallback.

    Parameters:
    - hostname (str): The host name to validate.

    Returns:
    - str: A public IP address for ``hostname``.

    Raises:
    - ValueError: If the loop's resolver fails with ``socket.gaierror``, or
      if no public address is found by either lookup.
    """
    # Inside a coroutine the running loop is the one to use;
    # get_running_loop() is the documented way to obtain it.
    loop = asyncio.get_running_loop()
    try:
        addrinfo = await loop.getaddrinfo(hostname, None)
    except socket.gaierror as e:
        raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

    for family, _, _, _, sockaddr in addrinfo:
        ip_address = sockaddr[0]
        if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
            return ip_address

    # Fallback lookup when the first resolver produced no public address.
    for ip_address in await async_resolve_hostname_google(hostname):
        if is_public_ip(ip_address):
            return ip_address

    raise ValueError(f"Hostname {hostname} failed validation")
class AsyncSecureTransport(httpx.AsyncHTTPTransport):
    """Async HTTP transport that connects to a pre-verified IP address.

    Every request handled by this transport has the host of its URL replaced
    by ``verified_ip``, while the ``Host`` header and the ``sni_hostname``
    request extension keep referring to the originally requested host.
    """

    def __init__(self, verified_ip: str):
        """Store the IP address that all requests will be sent to.

        Parameters:
        - verified_ip (str): The IP address to substitute for the URL host.
        """
        self.verified_ip = verified_ip
        super().__init__()

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Redirect ``request`` to ``verified_ip`` and send it.

        Parameters:
        - request (httpx.Request): The outgoing request; it is modified in place.

        Returns:
        - httpx.Response: Whatever the parent transport returns.
        """
        original_url = request.url
        original_host = original_url.host

        request.url = original_url.copy_with(host=self.verified_ip)
        # Use the original netloc (host plus any non-default port) for the
        # Host header; using the bare host would drop an explicit port.
        request.headers['Host'] = original_url.netloc.decode("ascii")
        # Add the SNI hint without discarding extensions already present on
        # the request (for example the client's timeout configuration).
        request.extensions = {**request.extensions, "sni_hostname": original_host}
        return await super().handle_async_request(request)
async def get(
    url: str,
    domain_whitelist: list[str] | None = None,
    _transport: httpx.AsyncBaseTransport | Literal[False] | None = None,
    **kwargs,
) -> httpx.Response:
    """
    Make an async HTTP GET request; this is the main entry point of the module.

    The transport used for the request is chosen as follows:
    1. A truthy `_transport` is used as given.
    2. If `_transport` is False, or the URL's host matches `domain_whitelist`,
       no custom transport is passed to the client.
    3. Otherwise the host is validated with `async_validate_url` and the
       request is sent through an `AsyncSecureTransport` bound to the
       verified IP address.

    Redirects are never followed (`follow_redirects=False` is always passed).

    NOTE(review): an earlier version of this docstring mentioned proxy
    environment variables; this function does not inspect them.

    Parameters:
    - url (str): The URL to make a GET request to.
    - domain_whitelist (list[str] | None): Domains that skip the secure transport. Supports wildcard subdomains with "*.domain.com" format (asterisk must be at the beginning).
    - _transport (httpx.AsyncBaseTransport | Literal[False] | None): A custom transport to use for the request. Takes precedence over domain_whitelist. Set to False to use no transport.
    - **kwargs: Additional keyword arguments to pass to the httpx.AsyncClient.get() function.

    Raises:
    - ValueError: If `url` has no hostname, or if hostname validation fails.
    """
    host = httpx.URL(url).host
    if not host:
        raise ValueError(f"URL {url} does not have a valid hostname")

    allowed_domains = [] if domain_whitelist is None else domain_whitelist

    if _transport:
        chosen_transport = _transport
    elif _transport is False or matches_domain_whitelist(host, allowed_domains):
        chosen_transport = None
    else:
        chosen_transport = AsyncSecureTransport(await async_validate_url(host))

    async with httpx.AsyncClient(transport=chosen_transport) as client:
        response = await client.get(url, follow_redirects=False, **kwargs)
    return response
|