| | """Various helper functions""" |
| |
|
| | import asyncio |
| | import base64 |
| | import binascii |
| | import contextlib |
| | import datetime |
| | import enum |
| | import functools |
| | import inspect |
| | import netrc |
| | import os |
| | import platform |
| | import re |
| | import sys |
| | import time |
| | import weakref |
| | from collections import namedtuple |
| | from contextlib import suppress |
| | from email.message import EmailMessage |
| | from email.parser import HeaderParser |
| | from email.policy import HTTP |
| | from email.utils import parsedate |
| | from math import ceil |
| | from pathlib import Path |
| | from types import MappingProxyType, TracebackType |
| | from typing import ( |
| | Any, |
| | Callable, |
| | ContextManager, |
| | Dict, |
| | Generator, |
| | Generic, |
| | Iterable, |
| | Iterator, |
| | List, |
| | Mapping, |
| | Optional, |
| | Protocol, |
| | Tuple, |
| | Type, |
| | TypeVar, |
| | Union, |
| | get_args, |
| | overload, |
| | ) |
| | from urllib.parse import quote |
| | from urllib.request import getproxies, proxy_bypass |
| |
|
| | import attr |
| | from multidict import MultiDict, MultiDictProxy, MultiMapping |
| | from propcache.api import under_cached_property as reify |
| | from yarl import URL |
| |
|
| | from . import hdrs |
| | from .log import client_logger |
| |
|
| | if sys.version_info >= (3, 11): |
| | import asyncio as async_timeout |
| | else: |
| | import async_timeout |
| |
|
# Public names exported by ``from <this module> import *``.
__all__ = (
    "BasicAuth",
    "ChainMapProxy",
    "ETag",
    "reify",
)
| |
|
# Operating-system flags, derived from ``platform.system()``.
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"

# Interpreter feature gates: true when running on at least that Python version.
PY_310 = sys.version_info[:2] >= (3, 10)
PY_311 = sys.version_info[:2] >= (3, 11)

# Generic type variables shared by the annotations in this module.
_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum whose only value serves as a unique "not set" marker.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# True when the AIOHTTP_NO_EXTENSIONS environment variable is set and non-empty.
NO_EXTENSIONS = bool(os.getenv("AIOHTTP_NO_EXTENSIONS"))

# Status codes listed here are treated as "must have an empty body":
# every 1xx code plus 204 and 304.
EMPTY_BODY_STATUS_CODES = frozenset(range(100, 200)) | {204, 304}
| | |
| | |
# Request methods for which a response is treated as having no body; this is
# simply the HEAD method collection exported by ``hdrs``.
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL
| |
|
# Debug mode is enabled by Python's dev mode, or by a non-empty
# PYTHONASYNCIODEBUG variable unless the interpreter was told to ignore the
# environment.
DEBUG = bool(
    sys.flags.dev_mode
    or (
        not sys.flags.ignore_environment
        and os.environ.get("PYTHONASYNCIODEBUG")
    )
)
| |
|
| |
|
# Character classes used to validate header tokens.

# Every 7-bit ASCII character.
CHAR = {chr(i) for i in range(0, 128)}

# Control characters: code points 0-31 plus DEL (127).
CTL = {chr(i) for i in range(0, 32)} | {chr(127)}

# Characters that delimit tokens and therefore can never be part of one.
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}

# A token character is any CHAR that is neither a CTL nor a separator.
# Set difference is required here: TAB belongs to both CTL and SEPARATORS, so
# a symmetric difference (``^``) would remove it once and then add it back,
# wrongly accepting TAB as a token character.
TOKEN = CHAR - CTL - SEPARATORS
| |
|
| |
|
class noop:
    """Awaitable whose ``__await__`` yields exactly once and then finishes."""

    def __await__(self) -> Generator[None, None, None]:
        # A single bare yield; the generator then returns None.
        yield None
| |
|
| |
|
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    An immutable ``(login, password, encoding)`` triple that can be rendered
    to, or parsed from, the value of an ``Authorization: Basic ...`` header.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate the credentials and build the tuple.

        :raises ValueError: if ``login`` or ``password`` is ``None`` or if
                            ``login`` contains a colon.
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        # The colon separates login from password on the wire, so it cannot
        # be part of the login itself.
        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :param auth_header: header value of the form ``Basic <base64>``.
        :param encoding: codec used to decode the credential bytes.
        :raises ValueError: if the header cannot be split into scheme and
                            payload, the scheme is not ``basic``, the payload
                            is not valid base64, or it contains no colon.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # Split on the first colon only: the separator must be present,
            # but the password itself is allowed to contain further colons.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(
        cls, url: "URL", *, encoding: str = "latin1"
    ) -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns ``None`` when the URL carries neither a user nor a password.

        :raises TypeError: if ``url`` is not a ``yarl.URL`` instance.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # The raw (still encoded) parts are only used to detect presence; the
        # decoded ``user``/``password`` values are what gets stored.
        if url.raw_user is None and url.raw_password is None:
            return None
        return cls(url.user or "", url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials.

        Returns the full header value, ``"Basic "`` followed by the base64
        form of ``login:password`` encoded with ``self.encoding``.
        """
        creds = f"{self.login}:{self.password}".encode(self.encoding)
        # base64 output is always plain ASCII. Decoding it with the credential
        # codec corrupts the header for codecs that are not ASCII-compatible
        # (for example utf-16), so decode it as ASCII explicitly.
        return "Basic %s" % base64.b64encode(creds).decode("ascii")
| |
|
| |
|
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split credentials off a URL.

    Returns ``(url, None)`` unchanged when the URL has neither a user nor a
    password. Otherwise returns the URL with the user removed together with a
    :class:`BasicAuth` built from the decoded user and password (missing parts
    become empty strings).
    """
    # Presence is decided on the raw (still encoded) components.
    has_credentials = not (url.raw_user is None and url.raw_password is None)
    if not has_credentials:
        return url, None

    stripped = url.with_user(None)
    auth = BasicAuth(url.user or "", url.password or "")
    return stripped, auth
| |
|
| |
|
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load a netrc file based on the environment.

    The path is taken from the ``NETRC`` environment variable when it is set;
    otherwise the default file in the user's home directory is used
    (``_netrc`` on Windows, ``.netrc`` elsewhere).

    Returns None if the file cannot be located, read or parsed.
    """
    netrc_env = os.environ.get("NETRC")

    if netrc_env is None:
        try:
            home_dir = Path.home()
        except RuntimeError as exc:
            # Without a resolvable home directory there is nowhere to look.
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None

        default_name = "_netrc" if IS_WINDOWS else ".netrc"
        netrc_path = home_dir / default_name
    else:
        netrc_path = Path(netrc_env)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # Probing the file may itself fail; treat that as "does not exist".
        try:
            netrc_exists = netrc_path.is_file()
        except OSError:
            netrc_exists = False

        # Only warn when the user explicitly pointed at a file through NETRC
        # or when the default file is really there but could not be read.
        if netrc_env or netrc_exists:
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None
| |
|
| |
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    """Immutable pair of a proxy URL and its optional credentials."""

    # Address of the proxy.
    proxy: URL
    # Credentials to use with the proxy, or None when there are none.
    proxy_auth: Optional[BasicAuth]
| |
|
| |
|
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Build :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")

    login, account, password = entry

    # Prefer ``login``; ``account`` is only used as the user name when
    # ``login`` is empty and an account value is actually present.
    if login or account is None:
        username = login
    else:
        username = account

    # A missing password is represented as an empty string.
    return BasicAuth(username, "" if password is None else password)
| |
|
| |
|
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect usable proxies from the environment.

    Reads the proxy mapping returned by ``getproxies()``, keeps the ``http``,
    ``https``, ``ws`` and ``wss`` entries, splits credentials off each URL and
    falls back to netrc credentials when the URL has none. Proxies whose own
    URL scheme is ``https`` or ``wss`` are skipped with a warning.

    Returns a mapping from protocol name to :class:`ProxyInfo`.
    """
    wanted_protocols = ("http", "https", "ws", "wss")
    proxy_urls = {}
    for proto, raw_url in getproxies().items():
        if proto in wanted_protocols:
            proxy_urls[proto] = URL(raw_url)

    netrc_obj = netrc_from_env()
    stripped = {proto: strip_auth_from_url(url) for proto, url in proxy_urls.items()}

    result: Dict[str, ProxyInfo] = {}
    for proto, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue

        # Only consult netrc when the URL itself carried no credentials.
        if auth is None and netrc_obj and proxy.host is not None:
            # A missing netrc entry simply leaves ``auth`` as None.
            with suppress(LookupError):
                auth = basicauth_from_netrc(netrc_obj, proxy.host)

        result[proto] = ProxyInfo(proxy, auth)
    return result
| |
|
| |
|
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Look up the environment-configured proxy that applies to ``url``.

    Returns the proxy URL and its optional credentials.

    :raises LookupError: if the host is excluded from proxying or if no proxy
            is configured for the URL's scheme.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    try:
        proxy_info = proxies_from_env()[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")

    return proxy_info.proxy, proxy_info.proxy_auth
| |
|
| |
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    """Immutable components of a parsed MIME type string."""

    # Part before the "/" (e.g. "text").
    type: str
    # Part after the "/" and before any "+" (e.g. "html").
    subtype: str
    # Part after the "+" in the subtype, or "" when there is none.
    suffix: str
    # Read-only view of the ";key=value" parameters.
    parameters: "MultiDictProxy[str]"
| |
|
| |
|
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into its components.

    ``mimetype`` is a MIME type string such as ``text/html; charset=utf-8``.
    An empty string produces a :class:`MimeType` with empty fields, and a
    bare ``*`` is treated as ``*/*``.

    Returns a MimeType object; results are memoised by ``lru_cache``.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        no_params: MultiDict[str] = MultiDict()
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(no_params)
        )

    fulltype, *raw_params = mimetype.split(";")

    params: MultiDict[str] = MultiDict()
    for raw_param in raw_params:
        if raw_param:
            # Keys are case-insensitive; values lose surrounding blanks/quotes.
            key, _, value = raw_param.partition("=")
            params.add(key.lower().strip(), value.strip(' "'))

    fulltype = fulltype.strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    mtype, _, remainder = fulltype.partition("/")
    stype, _, suffix = remainder.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )
| |
|
| |
|
class EnsureOctetStream(EmailMessage):
    """Email message whose default content type is application/octet-stream."""

    def __init__(self) -> None:
        super().__init__()
        # Replace the inherited default type with a binary one.
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Simplified re-implementation of ``Message.get_content_type``.

        Returns the lower-cased ``type/subtype`` part of the Content-Type
        header, or the default type (application/octet-stream) when that part
        does not contain exactly one ``/``.

        The way this class is used guarantees that content-type will
        be present, so the checks are simpler than in the base implementation.
        """
        header_value = self.get("content-type", "").lower()
        # Everything after the first ";" is parameters, not the type itself.
        ctype = header_value.split(";", 1)[0].strip()
        if ctype.count("/") == 1:
            return ctype
        return self.get_default_type()
| |
|
| |
|
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
    """Parse a raw Content-Type header value.

    Returns a tuple of the parsed content type and a read-only
    MappingProxyType of its parameters. The content type falls back to
    `application/octet-stream` when it cannot be determined.
    """
    parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    msg = parser.parsestr(f"Content-Type: {raw}")
    content_type = msg.get_content_type()
    # The first entry returned by get_params() is dropped; only the entries
    # after it are exposed as parameters.
    params = msg.get_params(())
    return content_type, MappingProxyType(dict(params[1:]))
| |
|
| |
|
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Derive a file name from ``obj.name``.

    Returns the final path component of ``obj.name`` when that attribute is a
    non-empty string that neither starts with ``<`` nor ends with ``>``;
    otherwise returns ``default``.
    """
    name = getattr(obj, "name", None)
    if not name or not isinstance(name, str):
        return default
    # Names written with angle brackets (such as "<stdin>") are not file names.
    if name[0] == "<" or name[-1] == ">":
        return default
    return Path(name).name
| |
|
| |
|
# Matches every character that has to be backslash-escaped inside a
# quoted-string, i.e. anything other than "!", "#".."[" and "]".."~".
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")

# Characters accepted as quoted-string content: printable ASCII (0x20-0x7E)
# plus horizontal tab.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.

    :param content: text to escape; may be empty.
    :returns: ``content`` with every non-qtext character prefixed by a
              backslash (the surrounding double quotes are not added).
    :raises ValueError: if ``content`` holds a character outside QCONTENT.
    """
    # A plain superset test is needed here: a *proper* superset check would
    # reject content that happens to use every allowed character.
    if not QCONTENT.issuperset(content):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
| |
|
| |
|
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Build the value of a ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    It must be a non-empty token (see RFC 2183).

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if the recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params; every key must be a
    non-empty token.

    Raises ValueError for an invalid disptype or parameter name.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError(f"bad content disposition type {disptype!r}")

    if not params:
        return disptype

    rendered = []
    for key, val in params.items():
        if not key or not (TOKEN > set(key)):
            raise ValueError(f"bad content disposition parameter {key!r}={val!r}")

        if not quote_fields:
            # 8-bit capable recipient: only escape backslashes and quotes.
            escaped = val.replace("\\", "\\\\").replace('"', '\\"')
            rendered.append(f'{key}="{escaped}"')
            continue

        if key.lower() == "filename":
            # File names are always percent-encoded inside plain quotes.
            encoded = quote(val, "", encoding=_charset)
            rendered.append(f'{key}="{encoded}"')
            continue

        try:
            quoted = quoted_string(val)
        except ValueError:
            # Not representable as a 7-bit quoted-string: emit the extended
            # ``key*=charset''percent-encoded`` form instead.
            encoded = quote(val, "", encoding=_charset)
            rendered.append("=".join((key + "*", "".join((_charset, "''", encoded)))))
        else:
            rendered.append(f'{key}="{quoted}"')

    return "; ".join((disptype, "; ".join(rendered)))
| |
|
| |
|
def is_ip_address(host: Optional[str]) -> bool:
    """Heuristically decide whether ``host`` is an IP address.

    This is not a validator: it is only meant to tell that a host is
    not a domain name. A falsy host is never an IP address.
    """
    if not host:
        return False
    # Any colon is taken as a sign of an IPv6 literal.
    if ":" in host:
        return True
    # Otherwise it counts as IPv4-like when only digits remain once the dots
    # are removed.
    return host.replace(".", "").isdigit()
| |
|
| |
|
# One-entry cache for rfc822_formatted_time(): the whole second that was last
# formatted, and the text produced for it.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current UTC time as ``Www, DD Mon YYYY HH:MM:SS GMT``.

    The text is rebuilt only when the whole-second timestamp differs from the
    cached one; otherwise the cached string is returned.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Fixed English names, indexed by tm_wday (Monday == 0) and tm_mon
    # (January == 1, hence the empty placeholder at index 0).
    weekday_names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    month_names = (
        "",
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    stamp = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekday_names[stamp.tm_wday],
        stamp.tm_mday,
        month_names[stamp.tm_mon],
        stamp.tm_year,
        stamp.tm_hour,
        stamp.tm_min,
        stamp.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
| |
|
| |
|
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
    """Call the named zero-argument method on a weakly referenced object.

    ``info`` is a ``(weak reference, method name)`` pair. Nothing happens if
    the referent is already gone, and any ``Exception`` raised while looking
    up or calling the method is swallowed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        method = getattr(target, name)
        method()
| |
|
| |
|
def weakref_handle(
    ob: object,
    name: str,
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``ob.<name>()`` after ``timeout`` seconds via a weak reference.

    Only a weak reference to ``ob`` is handed to the loop. The absolute
    deadline is rounded up to a whole number when ``timeout`` is at least
    ``timeout_ceil_threshold``.

    Returns the timer handle, or None when ``timeout`` is None or not
    positive (nothing is scheduled in that case).
    """
    if timeout is None or not (timeout > 0):
        return None

    deadline = loop.time() + timeout
    if timeout >= timeout_ceil_threshold:
        deadline = ceil(deadline)

    callback_info = (weakref.ref(ob), name)
    return loop.call_at(deadline, _weakref_handle, callback_info)
| |
|
| |
|
def call_later(
    cb: Callable[[], Any],
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``cb`` on ``loop`` to run ``timeout`` seconds from now.

    The deadline comes from ``calculate_timeout_when``. Returns the timer
    handle, or None (scheduling nothing) when ``timeout`` is None or not
    greater than zero.
    """
    nothing_to_schedule = timeout is None or timeout <= 0
    if nothing_to_schedule:
        return None

    deadline = calculate_timeout_when(loop.time(), timeout, timeout_ceil_threshold)
    return loop.call_at(deadline, cb)
| |
|
| |
|
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Compute the absolute time at which a timeout should fire.

    The result is ``loop_time + timeout``, rounded up to a whole number
    when ``timeout`` is strictly greater than ``timeout_ceiling_threshold``.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
| |
|
| |
|
class TimeoutHandle:
    """Collects callbacks that are all invoked once when a timeout fires."""

    __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._ceil_threshold = ceil_threshold
        # Each entry is (callback, positional args, keyword args).
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Remember ``callback`` and the arguments to call it with."""
        entry = (callback, args, kwargs)
        self._callbacks.append(entry)

    def close(self) -> None:
        """Drop every registered callback without calling it."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.TimerHandle]:
        """Schedule this handle on the loop.

        Returns the timer handle, or None when the timeout is None or not
        positive. The deadline is rounded up to a whole number when the
        timeout is at least the ceil threshold.
        """
        timeout = self._timeout
        if timeout is None or not (timeout > 0):
            return None

        deadline = self._loop.time() + timeout
        if timeout >= self._ceil_threshold:
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        With a positive timeout this is a ``TimerContext`` whose ``timeout``
        method is registered as a callback; otherwise it is a ``TimerNoop``.
        """
        if self._timeout is None or not (self._timeout > 0):
            return TimerNoop()

        timer = TimerContext(self._loop)
        self.register(timer.timeout)
        return timer

    def __call__(self) -> None:
        """Run every registered callback, then forget them all."""
        for callback, args, kwargs in self._callbacks:
            # An exception in one callback must not prevent the others.
            with suppress(Exception):
                callback(*args, **kwargs)

        self._callbacks.clear()
| |
|
| |
|
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base for the timer context managers in this module."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        The base implementation does nothing.
        """
| |
|
| |
|
class TimerNoop(BaseTimerContext):
    """Timer context that does nothing on entry or exit."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None lets any exception from the body propagate.
        return None
| |
|
| |
|
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager"""

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the context, most recently entered last.
        self._tasks: List[asyncio.Task[Any]] = []
        # Set once timeout() has run.
        self._cancelled = False
        # Value of task.cancelling() recorded on entry (Python 3.11+ only).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        task = asyncio.current_task(loop=self._loop)
        if task is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember how many cancellation requests were already pending so
            # that __exit__ can tell them apart from the one made by timeout().
            self._cancelling = task.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        enter_task: Optional[asyncio.Task[Any]] = (
            self._tasks.pop() if self._tasks else None
        )

        timed_out = exc_type is asyncio.CancelledError and self._cancelled
        if not timed_out:
            return None

        assert enter_task is not None
        # On 3.11+ undo one cancellation request. If more requests remain than
        # were pending on entry, return None so the CancelledError propagates
        # instead of being converted into a TimeoutError.
        if sys.version_info >= (3, 11) and enter_task.uncancel() > self._cancelling:
            return None
        raise asyncio.TimeoutError from exc_val

    def timeout(self) -> None:
        """Cancel every task inside the context and mark the timer as fired."""
        if self._cancelled:
            return

        for task in set(self._tasks):
            task.cancel()

        self._cancelled = True
| |
|
| |
|
| | def ceil_timeout( |
| | delay: Optional[float], ceil_threshold: float = 5 |
| | ) -> async_timeout.Timeout: |
| | if delay is None or delay <= 0: |
| | return async_timeout.timeout(None) |
| |
|
| | loop = asyncio.get_running_loop() |
| | now = loop.time() |
| | when = now + delay |
| | if delay > ceil_threshold: |
| | when = ceil(when) |
| | return async_timeout.timeout_at(when) |
| |
|
| |
|
class HeadersMixin:
    """Mixin exposing parsed views of the headers stored in ``_headers``."""

    ATTRS = frozenset(("_content_type", "_content_dict", "_stored_content_type"))

    _headers: MultiMapping[str]
    # Results of the last Content-Type parse, plus the raw header value that
    # produced them; ``sentinel`` means nothing has been parsed yet.
    _content_type: Optional[str] = None
    _content_dict: Optional[Dict[str, str]] = None
    _stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse ``raw`` and store the result together with ``raw`` itself."""
        self._stored_content_type = raw
        if raw is not None:
            content_type, params_proxy = parse_content_type(raw)
            self._content_type = content_type
            # copy() turns the read-only proxy into an independent mapping.
            self._content_dict = params_proxy.copy()
            return

        # No header at all: use the generic binary type and no parameters.
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        # Re-parse only when the header differs from the one parsed last.
        if raw != self._stored_content_type:
            self._parse_content_type(raw)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if raw != self._stored_content_type:
            self._parse_content_type(raw)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)
| |
|
| |
|
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``result`` on ``fut`` unless the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
| |
|
| |
|
# Unique placeholder exception instance, compared by identity, that stands for
# "no cause was supplied" in ``exc_cause`` parameters.
_EXC_SENTINEL = BaseException()
| |
|
| |
|
class ErrorableProtocol(Protocol):
    """Structural type for any object that offers ``set_exception``."""

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = ...,
    ) -> None:
        """Record ``exc``, optionally together with its direct cause."""
        ...
| |
|
| |
|
def set_exception(
    fut: "asyncio.Future[_T] | ErrorableProtocol",
    exc: BaseException,
    exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
    """Set an exception on a future or future-like object.

    If ``fut`` is a real future that is already done, this function is a
    no-op.

    :param exc_cause: An exception that is a direct cause of ``exc``.
                      Only attached when it was provided and is not
                      ``exc`` itself.
    """
    if asyncio.isfuture(fut) and fut.done():
        return

    cause_given = exc_cause is not _EXC_SENTINEL
    # An exception must never be recorded as its own cause.
    if cause_given and exc is not exc_cause:
        exc.__cause__ = exc_cause

    fut.set_exception(exc)
| |
|
| |
|
@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application.

    A key is identified by ``<module>.<name>``, where the module is the one
    whose top-level code is on the call stack when the key is created.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # Slot that can receive the parametrised alias when a key is created
    # through a subscripted class such as ``AppKey[int]("name")``; __repr__
    # falls back to it when no explicit type was passed to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        """Create a key.

        :param name: key name; it is prefixed with the defining module name.
        :param t: optional type of the value stored under this key.
        :raises RuntimeError: if no module-level frame is found on the call
                stack, so the module prefix cannot be determined.
        """
        # Walk up the stack to the nearest module-level frame and use its
        # ``__name__`` as the prefix.
        frame = inspect.currentframe()
        while frame:
            if frame.f_code.co_name == "<module>":
                module: str = frame.f_globals["__name__"]
                break
            frame = frame.f_back
        else:
            # Without this branch ``module`` would be unbound below, e.g. when
            # the key is created from a thread whose stack has no module frame.
            raise RuntimeError("Failed to get module name.")

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        # Keys are ordered by name among themselves; compared with any other
        # type a key always reports itself as smaller.
        if isinstance(other, AppKey):
            return self._name < other._name
        return True

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            # The slot is unset unless the key came from a subscripted class.
            with suppress(AttributeError):
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif isinstance(t, type):
            if t.__module__ == "builtins":
                t_repr = t.__qualname__
            else:
                t_repr = f"{t.__module__}.{t.__qualname__}"
        else:
            t_repr = repr(t)
        return f"<AppKey({self._name}, type={t_repr})>"
| |
|
| |
|
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    """Read-only view over several mappings searched in order.

    A lookup returns the value from the first mapping that has the key.
    Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        # The first mapping that knows the key wins.
        for mapping in self._maps:
            with suppress(KeyError):
                return mapping[key]
        raise KeyError(key)

    @overload
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        """Return the value for ``key``, or ``default`` when no mapping has it."""
        try:
            value = self[key]
        except KeyError:
            return default
        return value

    def __len__(self) -> int:
        # Number of distinct keys across all mappings.
        distinct_keys = set().union(*self._maps)
        return len(distinct_keys)

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        # Merge from the last mapping to the first so that earlier mappings
        # take precedence, then iterate over the merged keys.
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        for mapping in reversed(self._maps):
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for mapping in self._maps:
            if key in mapping:
                return True
        return False

    def __bool__(self) -> bool:
        # True as soon as one underlying mapping is non-empty.
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(repr(mapping) for mapping in self._maps)
        return f"ChainMapProxy({content})"
| |
|
| |
|
| | |
# Characters allowed inside an entity tag: "!", 0x23-0x7E and 0x80-0xFF
# (so neither a space nor a double quote).
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)

# A quoted entity tag with an optional weak prefix: group 1 is "W/" when
# present, group 2 is the tag value without the quotes.
_QUOTED_ETAG = '(W/)?"(' + _ETAGC + ')"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)

# For scanning a comma-separated list: each match is either one quoted tag
# followed by a comma or the end of the string, or (last group) one single
# character that fits nowhere else.
LIST_QUOTED_ETAG_RE = re.compile("(" + _QUOTED_ETAG + r")(?:\s*,\s*|$)|(.)")

# Wildcard entity tag value.
ETAG_ANY = "*"
| |
|
| |
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
    """Immutable entity tag: its value plus a weak/strong flag."""

    # Tag value.
    value: str
    # True for a weak tag; tags are strong unless stated otherwise.
    is_weak: bool = False
| |
|
| |
|
def validate_etag_value(value: str) -> None:
    """Check that ``value`` is usable as an entity tag value.

    The wildcard ``ETAG_ANY`` is always accepted; anything else must consist
    entirely of characters matched by ``_ETAGC_RE``.

    :raises ValueError: if ``value`` is not valid.
    """
    if value == ETAG_ANY:
        return
    if _ETAGC_RE.fullmatch(value) is None:
        raise ValueError(
            f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
        )
| |
|
| |
|
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Convert a date header string into a UTC-aware datetime.

    Returns None when ``date_str`` is None, cannot be parsed, or parses to
    field values that do not form a valid datetime.
    """
    if date_str is None:
        return None

    timetuple = parsedate(date_str)
    if timetuple is None:
        return None

    try:
        # Only year..second are used; the result is always marked as UTC.
        return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
    except ValueError:
        return None
| |
|
| |
|
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Tell whether the response to ``method`` with status ``code`` has no body.

    True when the status code is in ``EMPTY_BODY_STATUS_CODES``, when the
    method is in ``EMPTY_BODY_METHODS``, or for a 2xx answer to a CONNECT
    request. Results are memoised by ``lru_cache``.
    """
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    if method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
| |
|
| |
|
def should_remove_content_length(method: str, code: int) -> bool:
    """Tell whether a Content-Length header should be removed.

    True when the status code is in ``EMPTY_BODY_STATUS_CODES`` or for a 2xx
    answer to a CONNECT request. Unlike ``must_be_empty_body`` the
    ``EMPTY_BODY_METHODS`` are not considered here.

    This should always be a subset of must_be_empty_body
    """
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
| |
|