diff --git a/.gitattributes b/.gitattributes index 1c307b0c28278ec6d9aed3eca761354787b1a724..65584d1c6e10ad6b9324a0a08a2dc06e2c74315b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -205,3 +205,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/vllm/__pycache__/utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/grpc/_cython/cygrpc.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/wrapt/_wrappers.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/__pycache__/typing_extensions.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/__pycache__/pynvml.cpython-311.pyc b/.venv/lib/python3.11/site-packages/__pycache__/pynvml.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f914f8778f797fe3cb727105c8ad7ef87ce7b3e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/__pycache__/pynvml.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ec8fbd54c733bec6399caf5cd7da39d61df25426c6376b9e78be8d375f12722 +size 285515 diff --git a/.venv/lib/python3.11/site-packages/__pycache__/typing_extensions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/__pycache__/typing_extensions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3870c780c1fb2a6134bcbcf1bda3ccd908439e1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/__pycache__/typing_extensions.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4270657b146f00b5210a7cbf963ff4b514a08a0e7303eef5ba0a9e3a6c9a5e5b +size 151467 diff --git 
a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56dfc7987476103838b0cdc41400052f41b950bb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f193118a96311e9d8f0e4706e052914ef158053 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/core.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/gc_collector.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/gc_collector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71b1e17b7e2154c503619bb3be6216a0fc9ed341 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/gc_collector.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8cab1348a56de7d7f0fa1083db1aa307bf6c437 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/multiprocess.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/multiprocess.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f7d75e9f3870dddd53cd7a06ba9b7db9f68c7adc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/multiprocess.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/process_collector.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/process_collector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c58a156faa32845532245ba3d50c60185f4b9f28 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/process_collector.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/samples.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/samples.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77eeb5b821ad73f95201a7382721ad65037857e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/samples.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db393cece9589f8184d1dfead4d9e779d3745eee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__init__.py b/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__pycache__/__init__.cpython-311.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..9d20cce032a1a97e0b155a16c630ee7c354adb54 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__pycache__/graphite.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__pycache__/graphite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1a3ef283db435e0e4d1c3099e4f070064d8dd4a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/bridge/__pycache__/graphite.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/bridge/graphite.py b/.venv/lib/python3.11/site-packages/prometheus_client/bridge/graphite.py new file mode 100644 index 0000000000000000000000000000000000000000..8cadbedc53f7b45521af7fcc8c88e2fcfa7ec042 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/prometheus_client/bridge/graphite.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python + +import logging +import re +import socket +import threading +import time +from timeit import default_timer +from typing import Callable, Tuple + +from ..registry import CollectorRegistry, REGISTRY + +# Roughly, have to keep to what works as a file name. +# We also remove periods, so labels can be distinguished. + +_INVALID_GRAPHITE_CHARS = re.compile(r"[^a-zA-Z0-9_-]") + + +def _sanitize(s): + return _INVALID_GRAPHITE_CHARS.sub('_', s) + + +class _RegularPush(threading.Thread): + def __init__(self, pusher, interval, prefix): + super().__init__() + self._pusher = pusher + self._interval = interval + self._prefix = prefix + + def run(self): + wait_until = default_timer() + while True: + while True: + now = default_timer() + if now >= wait_until: + # May need to skip some pushes. + while wait_until < now: + wait_until += self._interval + break + # time.sleep can return early. 
+ time.sleep(wait_until - now) + try: + self._pusher.push(prefix=self._prefix) + except OSError: + logging.exception("Push failed") + + +class GraphiteBridge: + def __init__(self, + address: Tuple[str, int], + registry: CollectorRegistry = REGISTRY, + timeout_seconds: float = 30, + _timer: Callable[[], float] = time.time, + tags: bool = False, + ): + self._address = address + self._registry = registry + self._tags = tags + self._timeout = timeout_seconds + self._timer = _timer + + def push(self, prefix: str = '') -> None: + now = int(self._timer()) + output = [] + + prefixstr = '' + if prefix: + prefixstr = prefix + '.' + + for metric in self._registry.collect(): + for s in metric.samples: + if s.labels: + if self._tags: + sep = ';' + fmt = '{0}={1}' + else: + sep = '.' + fmt = '{0}.{1}' + labelstr = sep + sep.join( + [fmt.format( + _sanitize(k), _sanitize(v)) + for k, v in sorted(s.labels.items())]) + else: + labelstr = '' + output.append(f'{prefixstr}{_sanitize(s.name)}{labelstr} {float(s.value)} {now}\n') + + conn = socket.create_connection(self._address, self._timeout) + conn.sendall(''.join(output).encode('ascii')) + conn.close() + + def start(self, interval: float = 60.0, prefix: str = '') -> None: + t = _RegularPush(self, interval, prefix) + t.daemon = True + t.start() diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__init__.py b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87e0b8a6b90dd3841077ebefb650ad8c0b4ae33a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__init__.py @@ -0,0 +1,3 @@ +from ._exposition import MetricsResource + +__all__ = ['MetricsResource'] diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..10a763bd9d549f7c271d8e1047062daf69c707d3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__pycache__/_exposition.cpython-311.pyc b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__pycache__/_exposition.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd67e9d6a05a2e8ba7c181c91c599d39c7017fc3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/__pycache__/_exposition.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/prometheus_client/twisted/_exposition.py b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/_exposition.py new file mode 100644 index 0000000000000000000000000000000000000000..202a7d3bbde9325fce3bdc5c143e76cb66359380 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/prometheus_client/twisted/_exposition.py @@ -0,0 +1,8 @@ +from twisted.internet import reactor +from twisted.web.wsgi import WSGIResource + +from .. import exposition, REGISTRY + +MetricsResource = lambda registry=REGISTRY: WSGIResource( + reactor, reactor.getThreadPool(), exposition.make_wsgi_app(registry) +) diff --git a/.venv/lib/python3.11/site-packages/referencing/__init__.py b/.venv/lib/python3.11/site-packages/referencing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e09207d7e4b90aba221181d87886fd4f54038abf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/__init__.py @@ -0,0 +1,7 @@ +""" +Cross-specification, implementation-agnostic JSON referencing. 
+""" + +from referencing._core import Anchor, Registry, Resource, Specification + +__all__ = ["Anchor", "Registry", "Resource", "Specification"] diff --git a/.venv/lib/python3.11/site-packages/referencing/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d30bec32b353c5149084bfeb3c740ca9fe5df5b5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/__pycache__/_attrs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/__pycache__/_attrs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34a68dae10ffcc7b7e4108ee0d09b726eb3d8388 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/__pycache__/_attrs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/__pycache__/_core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/__pycache__/_core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34bc7829d76090518b0d1e407aaf9f92ad6557ef Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/__pycache__/_core.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a90943b66815043f639b92daa7852aee4bd0966c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/__pycache__/jsonschema.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/__pycache__/jsonschema.cpython-311.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..52c18dd071048d0d5785c9320326db5dfc7f7226 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/__pycache__/jsonschema.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/__pycache__/retrieval.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/__pycache__/retrieval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68d148cf71ac51c32fae245c27f15470239438ac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/__pycache__/retrieval.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/__pycache__/typing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/__pycache__/typing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d16140f23ebfbfabba415b2f06d77f89f4099c2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/__pycache__/typing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/_attrs.py b/.venv/lib/python3.11/site-packages/referencing/_attrs.py new file mode 100644 index 0000000000000000000000000000000000000000..ae85b865fed622afe83e8d6b7b17a1f0d174aba3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/_attrs.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import NoReturn, TypeVar + +from attrs import define as _define, frozen as _frozen + +_T = TypeVar("_T") + + +def define(cls: type[_T]) -> type[_T]: # pragma: no cover + cls.__init_subclass__ = _do_not_subclass + return _define(cls) + + +def frozen(cls: type[_T]) -> type[_T]: + cls.__init_subclass__ = _do_not_subclass + return _frozen(cls) + + +class UnsupportedSubclassing(Exception): + def __str__(self): + return ( + "Subclassing is not part of referencing's public API. 
" + "If no other suitable API exists for what you're trying to do, " + "feel free to file an issue asking for one." + ) + + +@staticmethod +def _do_not_subclass() -> NoReturn: # pragma: no cover + raise UnsupportedSubclassing() diff --git a/.venv/lib/python3.11/site-packages/referencing/_attrs.pyi b/.venv/lib/python3.11/site-packages/referencing/_attrs.pyi new file mode 100644 index 0000000000000000000000000000000000000000..278e4109b622dc3ecab7e3e0d0562ba594b80a33 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/_attrs.pyi @@ -0,0 +1,20 @@ +from typing import Any, Callable, TypeVar, Union + +from attr import attrib, field + +class UnsupportedSubclassing(Exception): ... + +_T = TypeVar("_T") + +def __dataclass_transform__( + *, + frozen_default: bool = False, + field_descriptors: tuple[Union[type, Callable[..., Any]], ...] = ..., +) -> Callable[[_T], _T]: ... +@__dataclass_transform__(field_descriptors=(attrib, field)) +def define(cls: type[_T]) -> type[_T]: ... +@__dataclass_transform__( + frozen_default=True, + field_descriptors=(attrib, field), +) +def frozen(cls: type[_T]) -> type[_T]: ... 
diff --git a/.venv/lib/python3.11/site-packages/referencing/_core.py b/.venv/lib/python3.11/site-packages/referencing/_core.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2d51bdc4c47e270502bdb22fe006135cd9c501 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/_core.py @@ -0,0 +1,739 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Sequence +from enum import Enum +from typing import Any, Callable, ClassVar, Generic, Protocol +from urllib.parse import unquote, urldefrag, urljoin + +from attrs import evolve, field +from rpds import HashTrieMap, HashTrieSet, List + +try: + from typing_extensions import TypeVar +except ImportError: # pragma: no cover + from typing import TypeVar + +from referencing import exceptions +from referencing._attrs import frozen +from referencing.typing import URI, Anchor as AnchorType, D, Mapping, Retrieve + +EMPTY_UNCRAWLED: HashTrieSet[URI] = HashTrieSet() +EMPTY_PREVIOUS_RESOLVERS: List[URI] = List() + + +class _Unset(Enum): + """ + What sillyness... + """ + + SENTINEL = 1 + + +_UNSET = _Unset.SENTINEL + + +class _MaybeInSubresource(Protocol[D]): + def __call__( + self, + segments: Sequence[int | str], + resolver: Resolver[D], + subresource: Resource[D], + ) -> Resolver[D]: ... 
+ + +def _detect_or_error(contents: D) -> Specification[D]: + if not isinstance(contents, Mapping): + raise exceptions.CannotDetermineSpecification(contents) + + jsonschema_dialect_id = contents.get("$schema") # type: ignore[reportUnknownMemberType] + if not isinstance(jsonschema_dialect_id, str): + raise exceptions.CannotDetermineSpecification(contents) + + from referencing.jsonschema import specification_with + + return specification_with(jsonschema_dialect_id) + + +def _detect_or_default( + default: Specification[D], +) -> Callable[[D], Specification[D]]: + def _detect(contents: D) -> Specification[D]: + if not isinstance(contents, Mapping): + return default + + jsonschema_dialect_id = contents.get("$schema") # type: ignore[reportUnknownMemberType] + if jsonschema_dialect_id is None: + return default + + from referencing.jsonschema import specification_with + + return specification_with( + jsonschema_dialect_id, # type: ignore[reportUnknownArgumentType] + default=default, + ) + + return _detect + + +class _SpecificationDetector: + def __get__( + self, + instance: Specification[D] | None, + cls: type[Specification[D]], + ) -> Callable[[D], Specification[D]]: + if instance is None: + return _detect_or_error + else: + return _detect_or_default(instance) + + +@frozen +class Specification(Generic[D]): + """ + A specification which defines referencing behavior. + + The various methods of a `Specification` allow for varying referencing + behavior across JSON Schema specification versions, etc. + """ + + #: A short human-readable name for the specification, used for debugging. + name: str + + #: Find the ID of a given document. + id_of: Callable[[D], URI | None] + + #: Retrieve the subresources of the given document (without traversing into + #: the subresources themselves). + subresources_of: Callable[[D], Iterable[D]] + + #: While resolving a JSON pointer, conditionally enter a subresource + #: (if e.g. 
we have just entered a keyword whose value is a subresource) + maybe_in_subresource: _MaybeInSubresource[D] + + #: Retrieve the anchors contained in the given document. + _anchors_in: Callable[ + [Specification[D], D], + Iterable[AnchorType[D]], + ] = field(alias="anchors_in") + + #: An opaque specification where resources have no subresources + #: nor internal identifiers. + OPAQUE: ClassVar[Specification[Any]] + + #: Attempt to discern which specification applies to the given contents. + #: + #: May be called either as an instance method or as a class method, with + #: slightly different behavior in the following case: + #: + #: Recall that not all contents contains enough internal information about + #: which specification it is written for -- the JSON Schema ``{}``, + #: for instance, is valid under many different dialects and may be + #: interpreted as any one of them. + #: + #: When this method is used as an instance method (i.e. called on a + #: specific specification), that specification is used as the default + #: if the given contents are unidentifiable. + #: + #: On the other hand when called as a class method, an error is raised. + #: + #: To reiterate, ``DRAFT202012.detect({})`` will return ``DRAFT202012`` + #: whereas the class method ``Specification.detect({})`` will raise an + #: error. + #: + #: (Note that of course ``DRAFT202012.detect(...)`` may return some other + #: specification when given a schema which *does* identify as being for + #: another version). + #: + #: Raises: + #: + #: `CannotDetermineSpecification` + #: + #: if the given contents don't have any discernible + #: information which could be used to guess which + #: specification they identify as + detect = _SpecificationDetector() + + def __repr__(self) -> str: + return f"" + + def anchors_in(self, contents: D): + """ + Retrieve the anchors contained in the given document. 
+ """ + return self._anchors_in(self, contents) + + def create_resource(self, contents: D) -> Resource[D]: + """ + Create a resource which is interpreted using this specification. + """ + return Resource(contents=contents, specification=self) + + +Specification.OPAQUE = Specification( + name="opaque", + id_of=lambda contents: None, + subresources_of=lambda contents: [], + anchors_in=lambda specification, contents: [], + maybe_in_subresource=lambda segments, resolver, subresource: resolver, +) + + +@frozen +class Resource(Generic[D]): + r""" + A document (deserialized JSON) with a concrete interpretation under a spec. + + In other words, a Python object, along with an instance of `Specification` + which describes how the document interacts with referencing -- both + internally (how it refers to other `Resource`\ s) and externally (how it + should be identified such that it is referenceable by other documents). + """ + + contents: D + _specification: Specification[D] = field(alias="specification") + + @classmethod + def from_contents( + cls, + contents: D, + default_specification: ( + type[Specification[D]] | Specification[D] + ) = Specification, + ) -> Resource[D]: + """ + Create a resource guessing which specification applies to the contents. + + Raises: + + `CannotDetermineSpecification` + + if the given contents don't have any discernible + information which could be used to guess which + specification they identify as + + """ + specification = default_specification.detect(contents) + return specification.create_resource(contents=contents) + + @classmethod + def opaque(cls, contents: D) -> Resource[D]: + """ + Create an opaque `Resource` -- i.e. one with opaque specification. + + See `Specification.OPAQUE` for details. + """ + return Specification.OPAQUE.create_resource(contents=contents) + + def id(self) -> URI | None: + """ + Retrieve this resource's (specification-specific) identifier. 
+ """ + id = self._specification.id_of(self.contents) + if id is None: + return + return id.rstrip("#") + + def subresources(self) -> Iterable[Resource[D]]: + """ + Retrieve this resource's subresources. + """ + return ( + Resource.from_contents( + each, + default_specification=self._specification, + ) + for each in self._specification.subresources_of(self.contents) + ) + + def anchors(self) -> Iterable[AnchorType[D]]: + """ + Retrieve this resource's (specification-specific) identifier. + """ + return self._specification.anchors_in(self.contents) + + def pointer(self, pointer: str, resolver: Resolver[D]) -> Resolved[D]: + """ + Resolve the given JSON pointer. + + Raises: + + `exceptions.PointerToNowhere` + + if the pointer points to a location not present in the document + + """ + if not pointer: + return Resolved(contents=self.contents, resolver=resolver) + + contents = self.contents + segments: list[int | str] = [] + for segment in unquote(pointer[1:]).split("/"): + if isinstance(contents, Sequence): + segment = int(segment) + else: + segment = segment.replace("~1", "/").replace("~0", "~") + try: + contents = contents[segment] # type: ignore[reportUnknownArgumentType] + except LookupError as lookup_error: + error = exceptions.PointerToNowhere(ref=pointer, resource=self) + raise error from lookup_error + + segments.append(segment) + last = resolver + resolver = self._specification.maybe_in_subresource( + segments=segments, + resolver=resolver, + subresource=self._specification.create_resource(contents), + ) + if resolver is not last: + segments = [] + return Resolved(contents=contents, resolver=resolver) # type: ignore[reportUnknownArgumentType] + + +def _fail_to_retrieve(uri: URI): + raise exceptions.NoSuchResource(ref=uri) + + +@frozen +class Registry(Mapping[URI, Resource[D]]): + r""" + A registry of `Resource`\ s, each identified by their canonical URIs. 
+ + Registries store a collection of in-memory resources, and optionally + enable additional resources which may be stored elsewhere (e.g. in a + database, a separate set of files, over the network, etc.). + + They also lazily walk their known resources, looking for subresources + within them. In other words, subresources contained within any added + resources will be retrievable via their own IDs (though this discovery of + subresources will be delayed until necessary). + + Registries are immutable, and their methods return new instances of the + registry with the additional resources added to them. + + The ``retrieve`` argument can be used to configure retrieval of resources + dynamically, either over the network, from a database, or the like. + Pass it a callable which will be called if any URI not present in the + registry is accessed. It must either return a `Resource` or else raise a + `NoSuchResource` exception indicating that the resource does not exist + even according to the retrieval logic. + """ + + _resources: HashTrieMap[URI, Resource[D]] = field( + default=HashTrieMap(), + converter=HashTrieMap.convert, # type: ignore[reportGeneralTypeIssues] + alias="resources", + ) + _anchors: HashTrieMap[tuple[URI, str], AnchorType[D]] = HashTrieMap() + _uncrawled: HashTrieSet[URI] = EMPTY_UNCRAWLED + _retrieve: Retrieve[D] = field(default=_fail_to_retrieve, alias="retrieve") + + def __getitem__(self, uri: URI) -> Resource[D]: + """ + Return the (already crawled) `Resource` identified by the given URI. + """ + try: + return self._resources[uri.rstrip("#")] + except KeyError: + raise exceptions.NoSuchResource(ref=uri) from None + + def __iter__(self) -> Iterator[URI]: + """ + Iterate over all crawled URIs in the registry. + """ + return iter(self._resources) + + def __len__(self) -> int: + """ + Count the total number of fully crawled resources in this registry. 
+ """ + return len(self._resources) + + def __rmatmul__( + self, + new: Resource[D] | Iterable[Resource[D]], + ) -> Registry[D]: + """ + Create a new registry with resource(s) added using their internal IDs. + + Resources must have a internal IDs (e.g. the :kw:`$id` keyword in + modern JSON Schema versions), otherwise an error will be raised. + + Both a single resource as well as an iterable of resources works, i.e.: + + * ``resource @ registry`` or + + * ``[iterable, of, multiple, resources] @ registry`` + + which -- again, assuming the resources have internal IDs -- is + equivalent to calling `Registry.with_resources` as such: + + .. code:: python + + registry.with_resources( + (resource.id(), resource) for resource in new_resources + ) + + Raises: + + `NoInternalID` + + if the resource(s) in fact do not have IDs + + """ + if isinstance(new, Resource): + new = (new,) + + resources = self._resources + uncrawled = self._uncrawled + for resource in new: + id = resource.id() + if id is None: + raise exceptions.NoInternalID(resource=resource) + uncrawled = uncrawled.insert(id) + resources = resources.insert(id, resource) + return evolve(self, resources=resources, uncrawled=uncrawled) + + def __repr__(self) -> str: + size = len(self) + pluralized = "resource" if size == 1 else "resources" + if self._uncrawled: + uncrawled = len(self._uncrawled) + if uncrawled == size: + summary = f"uncrawled {pluralized}" + else: + summary = f"{pluralized}, {uncrawled} uncrawled" + else: + summary = f"{pluralized}" + return f"" + + def get_or_retrieve(self, uri: URI) -> Retrieved[D, Resource[D]]: + """ + Get a resource from the registry, crawling or retrieving if necessary. + + May involve crawling to find the given URI if it is not already known, + so the returned object is a `Retrieved` object which contains both the + resource value as well as the registry which ultimately contained it. 
+ """ + resource = self._resources.get(uri) + if resource is not None: + return Retrieved(registry=self, value=resource) + + registry = self.crawl() + resource = registry._resources.get(uri) + if resource is not None: + return Retrieved(registry=registry, value=resource) + + try: + resource = registry._retrieve(uri) + except ( + exceptions.CannotDetermineSpecification, + exceptions.NoSuchResource, + ): + raise + except Exception as error: + raise exceptions.Unretrievable(ref=uri) from error + else: + registry = registry.with_resource(uri, resource) + return Retrieved(registry=registry, value=resource) + + def remove(self, uri: URI): + """ + Return a registry with the resource identified by a given URI removed. + """ + if uri not in self._resources: + raise exceptions.NoSuchResource(ref=uri) + + return evolve( + self, + resources=self._resources.remove(uri), + uncrawled=self._uncrawled.discard(uri), + anchors=HashTrieMap( + (k, v) for k, v in self._anchors.items() if k[0] != uri + ), + ) + + def anchor(self, uri: URI, name: str): + """ + Retrieve a given anchor from a resource which must already be crawled. + """ + value = self._anchors.get((uri, name)) + if value is not None: + return Retrieved(value=value, registry=self) + + registry = self.crawl() + value = registry._anchors.get((uri, name)) + if value is not None: + return Retrieved(value=value, registry=registry) + + resource = self[uri] + canonical_uri = resource.id() + if canonical_uri is not None: + value = registry._anchors.get((canonical_uri, name)) + if value is not None: + return Retrieved(value=value, registry=registry) + + if "/" in name: + raise exceptions.InvalidAnchor( + ref=uri, + resource=resource, + anchor=name, + ) + raise exceptions.NoSuchAnchor(ref=uri, resource=resource, anchor=name) + + def contents(self, uri: URI) -> D: + """ + Retrieve the (already crawled) contents identified by the given URI. 
+ """ + return self[uri].contents + + def crawl(self) -> Registry[D]: + """ + Crawl all added resources, discovering subresources. + """ + resources = self._resources + anchors = self._anchors + uncrawled = [(uri, resources[uri]) for uri in self._uncrawled] + while uncrawled: + uri, resource = uncrawled.pop() + + id = resource.id() + if id is not None: + uri = urljoin(uri, id) + resources = resources.insert(uri, resource) + for each in resource.anchors(): + anchors = anchors.insert((uri, each.name), each) + uncrawled.extend((uri, each) for each in resource.subresources()) + return evolve( + self, + resources=resources, + anchors=anchors, + uncrawled=EMPTY_UNCRAWLED, + ) + + def with_resource(self, uri: URI, resource: Resource[D]): + """ + Add the given `Resource` to the registry, without crawling it. + """ + return self.with_resources([(uri, resource)]) + + def with_resources( + self, + pairs: Iterable[tuple[URI, Resource[D]]], + ) -> Registry[D]: + r""" + Add the given `Resource`\ s to the registry, without crawling them. + """ + resources = self._resources + uncrawled = self._uncrawled + for uri, resource in pairs: + # Empty fragment URIs are equivalent to URIs without the fragment. + # TODO: Is this true for non JSON Schema resources? Probably not. + uri = uri.rstrip("#") + uncrawled = uncrawled.insert(uri) + resources = resources.insert(uri, resource) + return evolve(self, resources=resources, uncrawled=uncrawled) + + def with_contents( + self, + pairs: Iterable[tuple[URI, D]], + **kwargs: Any, + ) -> Registry[D]: + r""" + Add the given contents to the registry, autodetecting when necessary. + """ + return self.with_resources( + (uri, Resource.from_contents(each, **kwargs)) + for uri, each in pairs + ) + + def combine(self, *registries: Registry[D]) -> Registry[D]: + """ + Combine together one or more other registries, producing a unified one. 
+ """ + if registries == (self,): + return self + resources = self._resources + anchors = self._anchors + uncrawled = self._uncrawled + retrieve = self._retrieve + for registry in registries: + resources = resources.update(registry._resources) + anchors = anchors.update(registry._anchors) + uncrawled = uncrawled.update(registry._uncrawled) + + if registry._retrieve is not _fail_to_retrieve: # type: ignore[reportUnnecessaryComparison] ??? + if registry._retrieve is not retrieve is not _fail_to_retrieve: # type: ignore[reportUnnecessaryComparison] ??? + raise ValueError( # noqa: TRY003 + "Cannot combine registries with conflicting retrieval " + "functions.", + ) + retrieve = registry._retrieve + return evolve( + self, + anchors=anchors, + resources=resources, + uncrawled=uncrawled, + retrieve=retrieve, + ) + + def resolver(self, base_uri: URI = "") -> Resolver[D]: + """ + Return a `Resolver` which resolves references against this registry. + """ + return Resolver(base_uri=base_uri, registry=self) + + def resolver_with_root(self, resource: Resource[D]) -> Resolver[D]: + """ + Return a `Resolver` with a specific root resource. + """ + uri = resource.id() or "" + return Resolver( + base_uri=uri, + registry=self.with_resource(uri, resource), + ) + + +#: An anchor or resource. +AnchorOrResource = TypeVar( + "AnchorOrResource", + AnchorType[Any], + Resource[Any], + default=Resource[Any], +) + + +@frozen +class Retrieved(Generic[D, AnchorOrResource]): + """ + A value retrieved from a `Registry`. + """ + + value: AnchorOrResource + registry: Registry[D] + + +@frozen +class Resolved(Generic[D]): + """ + A reference resolved to its contents by a `Resolver`. + """ + + contents: D + resolver: Resolver[D] + + +@frozen +class Resolver(Generic[D]): + """ + A reference resolver. + + Resolvers help resolve references (including relative ones) by + pairing a fixed base URI with a `Registry`. 
+ + This object, under normal circumstances, is expected to be used by + *implementers of libraries* built on top of `referencing` (e.g. JSON Schema + implementations or other libraries resolving JSON references), + not directly by end-users populating registries or while writing + schemas or other resources. + + References are resolved against the base URI, and the combined URI + is then looked up within the registry. + + The process of resolving a reference may itself involve calculating + a *new* base URI for future reference resolution (e.g. if an + intermediate resource sets a new base URI), or may involve encountering + additional subresources and adding them to a new registry. + """ + + _base_uri: URI = field(alias="base_uri") + _registry: Registry[D] = field(alias="registry") + _previous: List[URI] = field(default=List(), repr=False, alias="previous") + + def lookup(self, ref: URI) -> Resolved[D]: + """ + Resolve the given reference to the resource it points to. + + Raises: + + `exceptions.Unresolvable` + + or a subclass thereof (see below) if the reference isn't + resolvable + + `exceptions.NoSuchAnchor` + + if the reference is to a URI where a resource exists but + contains a plain name fragment which does not exist within + the resource + + `exceptions.PointerToNowhere` + + if the reference is to a URI where a resource exists but + contains a JSON pointer to a location within the resource + that does not exist + + """ + if ref.startswith("#"): + uri, fragment = self._base_uri, ref[1:] + else: + uri, fragment = urldefrag(urljoin(self._base_uri, ref)) + try: + retrieved = self._registry.get_or_retrieve(uri) + except exceptions.NoSuchResource: + raise exceptions.Unresolvable(ref=ref) from None + except exceptions.Unretrievable as error: + raise exceptions.Unresolvable(ref=ref) from error + + if fragment.startswith("/"): + resolver = self._evolve(registry=retrieved.registry, base_uri=uri) + return retrieved.value.pointer(pointer=fragment, resolver=resolver) 
+ + if fragment: + retrieved = retrieved.registry.anchor(uri, fragment) + resolver = self._evolve(registry=retrieved.registry, base_uri=uri) + return retrieved.value.resolve(resolver=resolver) + + resolver = self._evolve(registry=retrieved.registry, base_uri=uri) + return Resolved(contents=retrieved.value.contents, resolver=resolver) + + def in_subresource(self, subresource: Resource[D]) -> Resolver[D]: + """ + Create a resolver for a subresource (which may have a new base URI). + """ + id = subresource.id() + if id is None: + return self + return evolve(self, base_uri=urljoin(self._base_uri, id)) + + def dynamic_scope(self) -> Iterable[tuple[URI, Registry[D]]]: + """ + In specs with such a notion, return the URIs in the dynamic scope. + """ + for uri in self._previous: + yield uri, self._registry + + def _evolve(self, base_uri: URI, **kwargs: Any): + """ + Evolve, appending to the dynamic scope. + """ + previous = self._previous + if self._base_uri and (not previous or base_uri != self._base_uri): + previous = previous.push_front(self._base_uri) + return evolve(self, base_uri=base_uri, previous=previous, **kwargs) + + +@frozen +class Anchor(Generic[D]): + """ + A simple anchor in a `Resource`. + """ + + name: str + resource: Resource[D] + + def resolve(self, resolver: Resolver[D]): + """ + Return the resource for this anchor. + """ + return Resolved(contents=self.resource.contents, resolver=resolver) diff --git a/.venv/lib/python3.11/site-packages/referencing/exceptions.py b/.venv/lib/python3.11/site-packages/referencing/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..3267fc70732e73c0a888d9f60551ad9373ed6d16 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/exceptions.py @@ -0,0 +1,165 @@ +""" +Errors, oh no! 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import attrs + +from referencing._attrs import frozen + +if TYPE_CHECKING: + from referencing import Resource + from referencing.typing import URI + + +@frozen +class NoSuchResource(KeyError): + """ + The given URI is not present in a registry. + + Unlike most exceptions, this class *is* intended to be publicly + instantiable and *is* part of the public API of the package. + """ + + ref: URI + + def __eq__(self, other: object) -> bool: + if self.__class__ is not other.__class__: + return NotImplemented + return attrs.astuple(self) == attrs.astuple(other) + + def __hash__(self) -> int: + return hash(attrs.astuple(self)) + + +@frozen +class NoInternalID(Exception): + """ + A resource has no internal ID, but one is needed. + + E.g. in modern JSON Schema drafts, this is the :kw:`$id` keyword. + + One might be needed if a resource was to-be added to a registry but no + other URI is available, and the resource doesn't declare its canonical URI. + """ + + resource: Resource[Any] + + def __eq__(self, other: object) -> bool: + if self.__class__ is not other.__class__: + return NotImplemented + return attrs.astuple(self) == attrs.astuple(other) + + def __hash__(self) -> int: + return hash(attrs.astuple(self)) + + +@frozen +class Unretrievable(KeyError): + """ + The given URI is not present in a registry, and retrieving it failed. + """ + + ref: URI + + def __eq__(self, other: object) -> bool: + if self.__class__ is not other.__class__: + return NotImplemented + return attrs.astuple(self) == attrs.astuple(other) + + def __hash__(self) -> int: + return hash(attrs.astuple(self)) + + +@frozen +class CannotDetermineSpecification(Exception): + """ + Attempting to detect the appropriate `Specification` failed. + + This happens if no discernible information is found in the contents of the + new resource which would help identify it. 
+ """ + + contents: Any + + def __eq__(self, other: object) -> bool: + if self.__class__ is not other.__class__: + return NotImplemented + return attrs.astuple(self) == attrs.astuple(other) + + def __hash__(self) -> int: + return hash(attrs.astuple(self)) + + +@attrs.frozen # Because here we allow subclassing below. +class Unresolvable(Exception): + """ + A reference was unresolvable. + """ + + ref: URI + + def __eq__(self, other: object) -> bool: + if self.__class__ is not other.__class__: + return NotImplemented + return attrs.astuple(self) == attrs.astuple(other) + + def __hash__(self) -> int: + return hash(attrs.astuple(self)) + + +@frozen +class PointerToNowhere(Unresolvable): + """ + A JSON Pointer leads to a part of a document that does not exist. + """ + + resource: Resource[Any] + + def __str__(self) -> str: + msg = f"{self.ref!r} does not exist within {self.resource.contents!r}" + if self.ref == "/": + msg += ( + ". The pointer '/' is a valid JSON Pointer but it points to " + "an empty string property ''. If you intended to point " + "to the entire resource, you should use '#'." + ) + return msg + + +@frozen +class NoSuchAnchor(Unresolvable): + """ + An anchor does not exist within a particular resource. + """ + + resource: Resource[Any] + anchor: str + + def __str__(self) -> str: + return ( + f"{self.anchor!r} does not exist within {self.resource.contents!r}" + ) + + +@frozen +class InvalidAnchor(Unresolvable): + """ + An anchor which could never exist in a resource was dereferenced. + + It is somehow syntactically invalid. + """ + + resource: Resource[Any] + anchor: str + + def __str__(self) -> str: + return ( + f"'#{self.anchor}' is not a valid anchor, neither as a " + "plain name anchor nor as a JSON Pointer. You may have intended " + f"to use '#/{self.anchor}', as the slash is required *before each " + "segment* of a JSON pointer." 
+ ) diff --git a/.venv/lib/python3.11/site-packages/referencing/jsonschema.py b/.venv/lib/python3.11/site-packages/referencing/jsonschema.py new file mode 100644 index 0000000000000000000000000000000000000000..169e109d914e558ec3693cef5ecdcd4dc82aedaa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/jsonschema.py @@ -0,0 +1,642 @@ +""" +Referencing implementations for JSON Schema specs (historic & current). +""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence, Set +from typing import Any, Union + +from referencing import Anchor, Registry, Resource, Specification, exceptions +from referencing._attrs import frozen +from referencing._core import ( + _UNSET, # type: ignore[reportPrivateUsage] + Resolved as _Resolved, + Resolver as _Resolver, + _Unset, # type: ignore[reportPrivateUsage] +) +from referencing.typing import URI, Anchor as AnchorType, Mapping + +#: A JSON Schema which is a JSON object +ObjectSchema = Mapping[str, Any] + +#: A JSON Schema of any kind +Schema = Union[bool, ObjectSchema] + +#: A Resource whose contents are JSON Schemas +SchemaResource = Resource[Schema] + +#: A JSON Schema Registry +SchemaRegistry = Registry[Schema] + +#: The empty JSON Schema Registry +EMPTY_REGISTRY: SchemaRegistry = Registry() + + +@frozen +class UnknownDialect(Exception): + """ + A dialect identifier was found for a dialect unknown by this library. + + If it's a custom ("unofficial") dialect, be sure you've registered it. 
+ """ + + uri: URI + + +def _dollar_id(contents: Schema) -> URI | None: + if isinstance(contents, bool): + return + return contents.get("$id") + + +def _legacy_dollar_id(contents: Schema) -> URI | None: + if isinstance(contents, bool) or "$ref" in contents: + return + id = contents.get("$id") + if id is not None and not id.startswith("#"): + return id + + +def _legacy_id(contents: ObjectSchema) -> URI | None: + if "$ref" in contents: + return + id = contents.get("id") + if id is not None and not id.startswith("#"): + return id + + +def _anchor( + specification: Specification[Schema], + contents: Schema, +) -> Iterable[AnchorType[Schema]]: + if isinstance(contents, bool): + return + anchor = contents.get("$anchor") + if anchor is not None: + yield Anchor( + name=anchor, + resource=specification.create_resource(contents), + ) + + dynamic_anchor = contents.get("$dynamicAnchor") + if dynamic_anchor is not None: + yield DynamicAnchor( + name=dynamic_anchor, + resource=specification.create_resource(contents), + ) + + +def _anchor_2019( + specification: Specification[Schema], + contents: Schema, +) -> Iterable[Anchor[Schema]]: + if isinstance(contents, bool): + return [] + anchor = contents.get("$anchor") + if anchor is None: + return [] + return [ + Anchor( + name=anchor, + resource=specification.create_resource(contents), + ), + ] + + +def _legacy_anchor_in_dollar_id( + specification: Specification[Schema], + contents: Schema, +) -> Iterable[Anchor[Schema]]: + if isinstance(contents, bool): + return [] + id = contents.get("$id", "") + if not id.startswith("#"): + return [] + return [ + Anchor( + name=id[1:], + resource=specification.create_resource(contents), + ), + ] + + +def _legacy_anchor_in_id( + specification: Specification[ObjectSchema], + contents: ObjectSchema, +) -> Iterable[Anchor[ObjectSchema]]: + id = contents.get("id", "") + if not id.startswith("#"): + return [] + return [ + Anchor( + name=id[1:], + resource=specification.create_resource(contents), + ), + 
] + + +def _subresources_of( + in_value: Set[str] = frozenset(), + in_subvalues: Set[str] = frozenset(), + in_subarray: Set[str] = frozenset(), +): + """ + Create a callable returning JSON Schema specification-style subschemas. + + Relies on specifying the set of keywords containing subschemas in their + values, in a subobject's values, or in a subarray. + """ + + def subresources_of(contents: Schema) -> Iterable[ObjectSchema]: + if isinstance(contents, bool): + return + for each in in_value: + if each in contents: + yield contents[each] + for each in in_subarray: + if each in contents: + yield from contents[each] + for each in in_subvalues: + if each in contents: + yield from contents[each].values() + + return subresources_of + + +def _subresources_of_with_crazy_items( + in_value: Set[str] = frozenset(), + in_subvalues: Set[str] = frozenset(), + in_subarray: Set[str] = frozenset(), +): + """ + Specifically handle older drafts where there are some funky keywords. + """ + + def subresources_of(contents: Schema) -> Iterable[ObjectSchema]: + if isinstance(contents, bool): + return + for each in in_value: + if each in contents: + yield contents[each] + for each in in_subarray: + if each in contents: + yield from contents[each] + for each in in_subvalues: + if each in contents: + yield from contents[each].values() + + items = contents.get("items") + if items is not None: + if isinstance(items, Sequence): + yield from items + else: + yield items + + return subresources_of + + +def _subresources_of_with_crazy_items_dependencies( + in_value: Set[str] = frozenset(), + in_subvalues: Set[str] = frozenset(), + in_subarray: Set[str] = frozenset(), +): + """ + Specifically handle older drafts where there are some funky keywords. 
+ """ + + def subresources_of(contents: Schema) -> Iterable[ObjectSchema]: + if isinstance(contents, bool): + return + for each in in_value: + if each in contents: + yield contents[each] + for each in in_subarray: + if each in contents: + yield from contents[each] + for each in in_subvalues: + if each in contents: + yield from contents[each].values() + + items = contents.get("items") + if items is not None: + if isinstance(items, Sequence): + yield from items + else: + yield items + dependencies = contents.get("dependencies") + if dependencies is not None: + values = iter(dependencies.values()) + value = next(values, None) + if isinstance(value, Mapping): + yield value + yield from values + + return subresources_of + + +def _subresources_of_with_crazy_aP_items_dependencies( + in_value: Set[str] = frozenset(), + in_subvalues: Set[str] = frozenset(), + in_subarray: Set[str] = frozenset(), +): + """ + Specifically handle even older drafts where there are some funky keywords. + """ + + def subresources_of(contents: ObjectSchema) -> Iterable[ObjectSchema]: + for each in in_value: + if each in contents: + yield contents[each] + for each in in_subarray: + if each in contents: + yield from contents[each] + for each in in_subvalues: + if each in contents: + yield from contents[each].values() + + items = contents.get("items") + if items is not None: + if isinstance(items, Sequence): + yield from items + else: + yield items + dependencies = contents.get("dependencies") + if dependencies is not None: + values = iter(dependencies.values()) + value = next(values, None) + if isinstance(value, Mapping): + yield value + yield from values + + for each in "additionalItems", "additionalProperties": + value = contents.get(each) + if isinstance(value, Mapping): + yield value + + return subresources_of + + +def _maybe_in_subresource( + in_value: Set[str] = frozenset(), + in_subvalues: Set[str] = frozenset(), + in_subarray: Set[str] = frozenset(), +): + in_child = in_subvalues | 
in_subarray + + def maybe_in_subresource( + segments: Sequence[int | str], + resolver: _Resolver[Any], + subresource: Resource[Any], + ) -> _Resolver[Any]: + _segments = iter(segments) + for segment in _segments: + if segment not in in_value and ( + segment not in in_child or next(_segments, None) is None + ): + return resolver + return resolver.in_subresource(subresource) + + return maybe_in_subresource + + +def _maybe_in_subresource_crazy_items( + in_value: Set[str] = frozenset(), + in_subvalues: Set[str] = frozenset(), + in_subarray: Set[str] = frozenset(), +): + in_child = in_subvalues | in_subarray + + def maybe_in_subresource( + segments: Sequence[int | str], + resolver: _Resolver[Any], + subresource: Resource[Any], + ) -> _Resolver[Any]: + _segments = iter(segments) + for segment in _segments: + if segment == "items" and isinstance( + subresource.contents, + Mapping, + ): + return resolver.in_subresource(subresource) + if segment not in in_value and ( + segment not in in_child or next(_segments, None) is None + ): + return resolver + return resolver.in_subresource(subresource) + + return maybe_in_subresource + + +def _maybe_in_subresource_crazy_items_dependencies( + in_value: Set[str] = frozenset(), + in_subvalues: Set[str] = frozenset(), + in_subarray: Set[str] = frozenset(), +): + in_child = in_subvalues | in_subarray + + def maybe_in_subresource( + segments: Sequence[int | str], + resolver: _Resolver[Any], + subresource: Resource[Any], + ) -> _Resolver[Any]: + _segments = iter(segments) + for segment in _segments: + if segment in {"items", "dependencies"} and isinstance( + subresource.contents, + Mapping, + ): + return resolver.in_subresource(subresource) + if segment not in in_value and ( + segment not in in_child or next(_segments, None) is None + ): + return resolver + return resolver.in_subresource(subresource) + + return maybe_in_subresource + + +#: JSON Schema draft 2020-12 +DRAFT202012 = Specification( + name="draft2020-12", + id_of=_dollar_id, + 
subresources_of=_subresources_of( + in_value={ + "additionalProperties", + "contains", + "contentSchema", + "else", + "if", + "items", + "not", + "propertyNames", + "then", + "unevaluatedItems", + "unevaluatedProperties", + }, + in_subarray={"allOf", "anyOf", "oneOf", "prefixItems"}, + in_subvalues={ + "$defs", + "definitions", + "dependentSchemas", + "patternProperties", + "properties", + }, + ), + anchors_in=_anchor, + maybe_in_subresource=_maybe_in_subresource( + in_value={ + "additionalProperties", + "contains", + "contentSchema", + "else", + "if", + "items", + "not", + "propertyNames", + "then", + "unevaluatedItems", + "unevaluatedProperties", + }, + in_subarray={"allOf", "anyOf", "oneOf", "prefixItems"}, + in_subvalues={ + "$defs", + "definitions", + "dependentSchemas", + "patternProperties", + "properties", + }, + ), +) +#: JSON Schema draft 2019-09 +DRAFT201909 = Specification( + name="draft2019-09", + id_of=_dollar_id, + subresources_of=_subresources_of_with_crazy_items( + in_value={ + "additionalItems", + "additionalProperties", + "contains", + "contentSchema", + "else", + "if", + "not", + "propertyNames", + "then", + "unevaluatedItems", + "unevaluatedProperties", + }, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={ + "$defs", + "definitions", + "dependentSchemas", + "patternProperties", + "properties", + }, + ), + anchors_in=_anchor_2019, + maybe_in_subresource=_maybe_in_subresource_crazy_items( + in_value={ + "additionalItems", + "additionalProperties", + "contains", + "contentSchema", + "else", + "if", + "not", + "propertyNames", + "then", + "unevaluatedItems", + "unevaluatedProperties", + }, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={ + "$defs", + "definitions", + "dependentSchemas", + "patternProperties", + "properties", + }, + ), +) +#: JSON Schema draft 7 +DRAFT7 = Specification( + name="draft-07", + id_of=_legacy_dollar_id, + subresources_of=_subresources_of_with_crazy_items_dependencies( + in_value={ + 
"additionalItems", + "additionalProperties", + "contains", + "else", + "if", + "not", + "propertyNames", + "then", + }, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), + anchors_in=_legacy_anchor_in_dollar_id, + maybe_in_subresource=_maybe_in_subresource_crazy_items_dependencies( + in_value={ + "additionalItems", + "additionalProperties", + "contains", + "else", + "if", + "not", + "propertyNames", + "then", + }, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), +) +#: JSON Schema draft 6 +DRAFT6 = Specification( + name="draft-06", + id_of=_legacy_dollar_id, + subresources_of=_subresources_of_with_crazy_items_dependencies( + in_value={ + "additionalItems", + "additionalProperties", + "contains", + "not", + "propertyNames", + }, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), + anchors_in=_legacy_anchor_in_dollar_id, + maybe_in_subresource=_maybe_in_subresource_crazy_items_dependencies( + in_value={ + "additionalItems", + "additionalProperties", + "contains", + "not", + "propertyNames", + }, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), +) +#: JSON Schema draft 4 +DRAFT4 = Specification( + name="draft-04", + id_of=_legacy_id, + subresources_of=_subresources_of_with_crazy_aP_items_dependencies( + in_value={"not"}, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), + anchors_in=_legacy_anchor_in_id, + maybe_in_subresource=_maybe_in_subresource_crazy_items_dependencies( + in_value={"additionalItems", "additionalProperties", "not"}, + in_subarray={"allOf", "anyOf", "oneOf"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), +) +#: JSON Schema draft 3 +DRAFT3 = Specification( + name="draft-03", + id_of=_legacy_id, + 
subresources_of=_subresources_of_with_crazy_aP_items_dependencies( + in_subarray={"extends"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), + anchors_in=_legacy_anchor_in_id, + maybe_in_subresource=_maybe_in_subresource_crazy_items_dependencies( + in_value={"additionalItems", "additionalProperties"}, + in_subarray={"extends"}, + in_subvalues={"definitions", "patternProperties", "properties"}, + ), +) + + +_SPECIFICATIONS: Registry[Specification[Schema]] = Registry( + { + dialect_id: Resource.opaque(specification) + for dialect_id, specification in [ + ("https://json-schema.org/draft/2020-12/schema", DRAFT202012), + ("https://json-schema.org/draft/2019-09/schema", DRAFT201909), + ("http://json-schema.org/draft-07/schema", DRAFT7), + ("http://json-schema.org/draft-06/schema", DRAFT6), + ("http://json-schema.org/draft-04/schema", DRAFT4), + ("http://json-schema.org/draft-03/schema", DRAFT3), + ] + }, +) + + +def specification_with( + dialect_id: URI, + default: Specification[Any] | _Unset = _UNSET, +) -> Specification[Any]: + """ + Retrieve the `Specification` with the given dialect identifier. + + Raises: + + `UnknownDialect` + + if the given ``dialect_id`` isn't known + + """ + resource = _SPECIFICATIONS.get(dialect_id.rstrip("#")) + if resource is not None: + return resource.contents + if default is _UNSET: + raise UnknownDialect(dialect_id) + return default + + +@frozen +class DynamicAnchor: + """ + Dynamic anchors, introduced in draft 2020. + """ + + name: str + resource: SchemaResource + + def resolve(self, resolver: _Resolver[Schema]) -> _Resolved[Schema]: + """ + Resolve this anchor dynamically. 
+ """ + last = self.resource + for uri, registry in resolver.dynamic_scope(): + try: + anchor = registry.anchor(uri, self.name).value + except exceptions.NoSuchAnchor: + continue + if isinstance(anchor, DynamicAnchor): + last = anchor.resource + return _Resolved( + contents=last.contents, + resolver=resolver.in_subresource(last), + ) + + +def lookup_recursive_ref(resolver: _Resolver[Schema]) -> _Resolved[Schema]: + """ + Recursive references (via recursive anchors), present only in draft 2019. + + As per the 2019 specification (§ 8.2.4.2.1), only the ``#`` recursive + reference is supported (and is therefore assumed to be the relevant + reference). + """ + resolved = resolver.lookup("#") + if isinstance(resolved.contents, Mapping) and resolved.contents.get( + "$recursiveAnchor", + ): + for uri, _ in resolver.dynamic_scope(): + next_resolved = resolver.lookup(uri) + if not isinstance( + next_resolved.contents, + Mapping, + ) or not next_resolved.contents.get("$recursiveAnchor"): + break + resolved = next_resolved + return resolved diff --git a/.venv/lib/python3.11/site-packages/referencing/py.typed b/.venv/lib/python3.11/site-packages/referencing/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/referencing/retrieval.py b/.venv/lib/python3.11/site-packages/referencing/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..53e0512b199fb014d11075ee3047c848ed7c2d69 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/retrieval.py @@ -0,0 +1,92 @@ +""" +Helpers related to (dynamic) resource retrieval. 
+""" + +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Callable +import json + +try: + from typing_extensions import TypeVar +except ImportError: # pragma: no cover + from typing import TypeVar + +from referencing import Resource + +if TYPE_CHECKING: + from referencing.typing import URI, D, Retrieve + +#: A serialized document (e.g. a JSON string) +_T = TypeVar("_T", default=str) + + +def to_cached_resource( + cache: Callable[[Retrieve[D]], Retrieve[D]] | None = None, + loads: Callable[[_T], D] = json.loads, + from_contents: Callable[[D], Resource[D]] = Resource.from_contents, +) -> Callable[[Callable[[URI], _T]], Retrieve[D]]: + """ + Create a retriever which caches its return values from a simpler callable. + + Takes a function which returns things like serialized JSON (strings) and + returns something suitable for passing to `Registry` as a retrieve + function. + + This decorator both reduces a small bit of boilerplate for a common case + (deserializing JSON from strings and creating `Resource` objects from the + result) as well as makes the probable need for caching a bit easier. + Retrievers which otherwise do expensive operations (like hitting the + network) might otherwise be called repeatedly. + + Examples + -------- + + .. testcode:: + + from referencing import Registry + from referencing.typing import URI + import referencing.retrieval + + + @referencing.retrieval.to_cached_resource() + def retrieve(uri: URI): + print(f"Retrieved {uri}") + + # Normally, go get some expensive JSON from the network, a file ... 
+ return ''' + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "foo": "bar" + } + ''' + + one = Registry(retrieve=retrieve).get_or_retrieve("urn:example:foo") + print(one.value.contents["foo"]) + + # Retrieving the same URI again reuses the same value (and thus doesn't + # print another retrieval message here) + two = Registry(retrieve=retrieve).get_or_retrieve("urn:example:foo") + print(two.value.contents["foo"]) + + .. testoutput:: + + Retrieved urn:example:foo + bar + bar + + """ + if cache is None: + cache = lru_cache(maxsize=None) + + def decorator(retrieve: Callable[[URI], _T]): + @cache + def cached_retrieve(uri: URI): + response = retrieve(uri) + contents = loads(response) + return from_contents(contents) + + return cached_retrieve + + return decorator diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/__init__.py b/.venv/lib/python3.11/site-packages/referencing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82f203a23e741770c65f549179612f83faa92864 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd94f4420f593f31fe11b56a955dd19ae2d7666a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_core.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aaeb0b3cc5a29593f1dc607b45182c0817cd3ed Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_jsonschema.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_jsonschema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92e37aac799ccfd33202d70c0c3437621de69a55 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_jsonschema.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_referencing_suite.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_referencing_suite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..145c0085caf1e9483ac62098e2ff272ceebcc03a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_referencing_suite.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_retrieval.cpython-311.pyc b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_retrieval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9707b35eb9b1afa4f02a78511ff98437118af86c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/referencing/tests/__pycache__/test_retrieval.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/test_core.py b/.venv/lib/python3.11/site-packages/referencing/tests/test_core.py new file mode 100644 index 
0000000000000000000000000000000000000000..3edddbc3d96581e1c74069baa873900495366bab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/tests/test_core.py @@ -0,0 +1,1057 @@ +from rpds import HashTrieMap +import pytest + +from referencing import Anchor, Registry, Resource, Specification, exceptions +from referencing.jsonschema import DRAFT202012 + +ID_AND_CHILDREN = Specification( + name="id-and-children", + id_of=lambda contents: contents.get("ID"), + subresources_of=lambda contents: contents.get("children", []), + anchors_in=lambda specification, contents: [ + Anchor( + name=name, + resource=specification.create_resource(contents=each), + ) + for name, each in contents.get("anchors", {}).items() + ], + maybe_in_subresource=lambda segments, resolver, subresource: ( + resolver.in_subresource(subresource) + if not len(segments) % 2 + and all(each == "children" for each in segments[::2]) + else resolver + ), +) + + +def blow_up(uri): # pragma: no cover + """ + A retriever suitable for use in tests which expect it never to be used. + """ + raise RuntimeError("This retrieve function expects to never be called!") + + +class TestRegistry: + def test_with_resource(self): + """ + Adding a resource to the registry then allows re-retrieving it. + """ + + resource = Resource.opaque(contents={"foo": "bar"}) + uri = "urn:example" + registry = Registry().with_resource(uri=uri, resource=resource) + assert registry[uri] is resource + + def test_with_resources(self): + """ + Adding multiple resources to the registry is like adding each one. 
+ """ + + one = Resource.opaque(contents={}) + two = Resource(contents={"foo": "bar"}, specification=ID_AND_CHILDREN) + registry = Registry().with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/foo/bar", two), + ], + ) + assert registry == Registry().with_resource( + uri="http://example.com/1", + resource=one, + ).with_resource( + uri="http://example.com/foo/bar", + resource=two, + ) + + def test_matmul_resource(self): + uri = "urn:example:resource" + resource = ID_AND_CHILDREN.create_resource({"ID": uri, "foo": 12}) + registry = resource @ Registry() + assert registry == Registry().with_resource(uri, resource) + + def test_matmul_many_resources(self): + one_uri = "urn:example:one" + one = ID_AND_CHILDREN.create_resource({"ID": one_uri, "foo": 12}) + + two_uri = "urn:example:two" + two = ID_AND_CHILDREN.create_resource({"ID": two_uri, "foo": 12}) + + registry = [one, two] @ Registry() + assert registry == Registry().with_resources( + [(one_uri, one), (two_uri, two)], + ) + + def test_matmul_resource_without_id(self): + resource = Resource.opaque(contents={"foo": "bar"}) + with pytest.raises(exceptions.NoInternalID) as e: + resource @ Registry() + assert e.value == exceptions.NoInternalID(resource=resource) + + def test_with_contents_from_json_schema(self): + uri = "urn:example" + schema = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + registry = Registry().with_contents([(uri, schema)]) + + expected = Resource(contents=schema, specification=DRAFT202012) + assert registry[uri] == expected + + def test_with_contents_and_default_specification(self): + uri = "urn:example" + registry = Registry().with_contents( + [(uri, {"foo": "bar"})], + default_specification=Specification.OPAQUE, + ) + assert registry[uri] == Resource.opaque({"foo": "bar"}) + + def test_len(self): + total = 5 + registry = Registry().with_contents( + [(str(i), {"foo": "bar"}) for i in range(total)], + default_specification=Specification.OPAQUE, + ) + assert 
len(registry) == total + + def test_bool_empty(self): + assert not Registry() + + def test_bool_not_empty(self): + registry = Registry().with_contents( + [(str(i), {"foo": "bar"}) for i in range(3)], + default_specification=Specification.OPAQUE, + ) + assert registry + + def test_iter(self): + registry = Registry().with_contents( + [(str(i), {"foo": "bar"}) for i in range(8)], + default_specification=Specification.OPAQUE, + ) + assert set(registry) == {str(i) for i in range(8)} + + def test_crawl_still_has_top_level_resource(self): + resource = Resource.opaque({"foo": "bar"}) + uri = "urn:example" + registry = Registry({uri: resource}).crawl() + assert registry[uri] is resource + + def test_crawl_finds_a_subresource(self): + child_id = "urn:child" + root = ID_AND_CHILDREN.create_resource( + {"ID": "urn:root", "children": [{"ID": child_id, "foo": 12}]}, + ) + registry = root @ Registry() + with pytest.raises(LookupError): + registry[child_id] + + expected = ID_AND_CHILDREN.create_resource({"ID": child_id, "foo": 12}) + assert registry.crawl()[child_id] == expected + + def test_crawl_finds_anchors_with_id(self): + resource = ID_AND_CHILDREN.create_resource( + {"ID": "urn:bar", "anchors": {"foo": 12}}, + ) + registry = resource @ Registry() + + assert registry.crawl().anchor(resource.id(), "foo").value == Anchor( + name="foo", + resource=ID_AND_CHILDREN.create_resource(12), + ) + + def test_crawl_finds_anchors_no_id(self): + resource = ID_AND_CHILDREN.create_resource({"anchors": {"foo": 12}}) + registry = Registry().with_resource("urn:root", resource) + + assert registry.crawl().anchor("urn:root", "foo").value == Anchor( + name="foo", + resource=ID_AND_CHILDREN.create_resource(12), + ) + + def test_contents(self): + resource = Resource.opaque({"foo": "bar"}) + uri = "urn:example" + registry = Registry().with_resource(uri, resource) + assert registry.contents(uri) == {"foo": "bar"} + + def test_getitem_strips_empty_fragments(self): + uri = "http://example.com/" + 
resource = ID_AND_CHILDREN.create_resource({"ID": uri + "#"}) + registry = resource @ Registry() + assert registry[uri] == registry[uri + "#"] == resource + + def test_contents_strips_empty_fragments(self): + uri = "http://example.com/" + resource = ID_AND_CHILDREN.create_resource({"ID": uri + "#"}) + registry = resource @ Registry() + assert ( + registry.contents(uri) + == registry.contents(uri + "#") + == {"ID": uri + "#"} + ) + + def test_contents_nonexistent_resource(self): + registry = Registry() + with pytest.raises(exceptions.NoSuchResource) as e: + registry.contents("urn:example") + assert e.value == exceptions.NoSuchResource(ref="urn:example") + + def test_crawled_anchor(self): + resource = ID_AND_CHILDREN.create_resource({"anchors": {"foo": "bar"}}) + registry = Registry().with_resource("urn:example", resource) + retrieved = registry.anchor("urn:example", "foo") + assert retrieved.value == Anchor( + name="foo", + resource=ID_AND_CHILDREN.create_resource("bar"), + ) + assert retrieved.registry == registry.crawl() + + def test_anchor_in_nonexistent_resource(self): + registry = Registry() + with pytest.raises(exceptions.NoSuchResource) as e: + registry.anchor("urn:example", "foo") + assert e.value == exceptions.NoSuchResource(ref="urn:example") + + def test_init(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + registry = Registry( + { + "http://example.com/1": one, + "http://example.com/foo/bar": two, + }, + ) + assert ( + registry + == Registry() + .with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/foo/bar", two), + ], + ) + .crawl() + ) + + def test_dict_conversion(self): + """ + Passing a `dict` to `Registry` gets converted to a `HashTrieMap`. + + So continuing to use the registry works. 
+ """ + + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + registry = Registry( + {"http://example.com/1": one}, + ).with_resource("http://example.com/foo/bar", two) + assert ( + registry.crawl() + == Registry() + .with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/foo/bar", two), + ], + ) + .crawl() + ) + + def test_no_such_resource(self): + registry = Registry() + with pytest.raises(exceptions.NoSuchResource) as e: + registry["urn:bigboom"] + assert e.value == exceptions.NoSuchResource(ref="urn:bigboom") + + def test_combine(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + three = ID_AND_CHILDREN.create_resource({"baz": "quux"}) + four = ID_AND_CHILDREN.create_resource({"anchors": {"foo": 12}}) + + first = Registry({"http://example.com/1": one}) + second = Registry().with_resource("http://example.com/foo/bar", two) + third = Registry( + { + "http://example.com/1": one, + "http://example.com/baz": three, + }, + ) + fourth = ( + Registry() + .with_resource( + "http://example.com/foo/quux", + four, + ) + .crawl() + ) + assert first.combine(second, third, fourth) == Registry( + [ + ("http://example.com/1", one), + ("http://example.com/baz", three), + ("http://example.com/foo/quux", four), + ], + anchors=HashTrieMap( + { + ("http://example.com/foo/quux", "foo"): Anchor( + name="foo", + resource=ID_AND_CHILDREN.create_resource(12), + ), + }, + ), + ).with_resource("http://example.com/foo/bar", two) + + def test_combine_self(self): + """ + Combining a registry with itself short-circuits. + + This is a performance optimization -- otherwise we do lots more work + (in jsonschema this seems to correspond to making the test suite take + *3x* longer). 
+ """ + + registry = Registry({"urn:foo": "bar"}) + assert registry.combine(registry) is registry + + def test_combine_with_uncrawled_resources(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + three = ID_AND_CHILDREN.create_resource({"baz": "quux"}) + + first = Registry().with_resource("http://example.com/1", one) + second = Registry().with_resource("http://example.com/foo/bar", two) + third = Registry( + { + "http://example.com/1": one, + "http://example.com/baz": three, + }, + ) + expected = Registry( + [ + ("http://example.com/1", one), + ("http://example.com/foo/bar", two), + ("http://example.com/baz", three), + ], + ) + combined = first.combine(second, third) + assert combined != expected + assert combined.crawl() == expected + + def test_combine_with_single_retrieve(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + three = ID_AND_CHILDREN.create_resource({"baz": "quux"}) + + def retrieve(uri): # pragma: no cover + pass + + first = Registry().with_resource("http://example.com/1", one) + second = Registry( + retrieve=retrieve, + ).with_resource("http://example.com/2", two) + third = Registry().with_resource("http://example.com/3", three) + + assert first.combine(second, third) == Registry( + retrieve=retrieve, + ).with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/2", two), + ("http://example.com/3", three), + ], + ) + assert second.combine(first, third) == Registry( + retrieve=retrieve, + ).with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/2", two), + ("http://example.com/3", three), + ], + ) + + def test_combine_with_common_retrieve(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + three = ID_AND_CHILDREN.create_resource({"baz": "quux"}) + + def retrieve(uri): # pragma: no cover + pass + + first = Registry(retrieve=retrieve).with_resource( + 
"http://example.com/1", + one, + ) + second = Registry( + retrieve=retrieve, + ).with_resource("http://example.com/2", two) + third = Registry(retrieve=retrieve).with_resource( + "http://example.com/3", + three, + ) + + assert first.combine(second, third) == Registry( + retrieve=retrieve, + ).with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/2", two), + ("http://example.com/3", three), + ], + ) + assert second.combine(first, third) == Registry( + retrieve=retrieve, + ).with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/2", two), + ("http://example.com/3", three), + ], + ) + + def test_combine_conflicting_retrieve(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + three = ID_AND_CHILDREN.create_resource({"baz": "quux"}) + + def foo_retrieve(uri): # pragma: no cover + pass + + def bar_retrieve(uri): # pragma: no cover + pass + + first = Registry(retrieve=foo_retrieve).with_resource( + "http://example.com/1", + one, + ) + second = Registry().with_resource("http://example.com/2", two) + third = Registry(retrieve=bar_retrieve).with_resource( + "http://example.com/3", + three, + ) + + with pytest.raises(Exception, match="conflict.*retriev"): + first.combine(second, third) + + def test_remove(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + registry = Registry({"urn:foo": one, "urn:bar": two}) + assert registry.remove("urn:foo") == Registry({"urn:bar": two}) + + def test_remove_uncrawled(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + registry = Registry().with_resources( + [("urn:foo", one), ("urn:bar", two)], + ) + assert registry.remove("urn:foo") == Registry().with_resource( + "urn:bar", + two, + ) + + def test_remove_with_anchors(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"anchors": {"foo": "bar"}}) + registry = ( + 
Registry() + .with_resources( + [("urn:foo", one), ("urn:bar", two)], + ) + .crawl() + ) + assert ( + registry.remove("urn:bar") + == Registry() + .with_resource( + "urn:foo", + one, + ) + .crawl() + ) + + def test_remove_nonexistent_uri(self): + with pytest.raises(exceptions.NoSuchResource) as e: + Registry().remove("urn:doesNotExist") + assert e.value == exceptions.NoSuchResource(ref="urn:doesNotExist") + + def test_retrieve(self): + foo = Resource.opaque({"foo": "bar"}) + registry = Registry(retrieve=lambda uri: foo) + assert registry.get_or_retrieve("urn:example").value == foo + + def test_retrieve_arbitrary_exception(self): + foo = Resource.opaque({"foo": "bar"}) + + def retrieve(uri): + if uri == "urn:succeed": + return foo + raise Exception("Oh no!") + + registry = Registry(retrieve=retrieve) + assert registry.get_or_retrieve("urn:succeed").value == foo + with pytest.raises(exceptions.Unretrievable): + registry.get_or_retrieve("urn:uhoh") + + def test_retrieve_no_such_resource(self): + foo = Resource.opaque({"foo": "bar"}) + + def retrieve(uri): + if uri == "urn:succeed": + return foo + raise exceptions.NoSuchResource(ref=uri) + + registry = Registry(retrieve=retrieve) + assert registry.get_or_retrieve("urn:succeed").value == foo + with pytest.raises(exceptions.NoSuchResource): + registry.get_or_retrieve("urn:uhoh") + + def test_retrieve_cannot_determine_specification(self): + def retrieve(uri): + return Resource.from_contents({}) + + registry = Registry(retrieve=retrieve) + with pytest.raises(exceptions.CannotDetermineSpecification): + registry.get_or_retrieve("urn:uhoh") + + def test_retrieve_already_available_resource(self): + foo = Resource.opaque({"foo": "bar"}) + registry = Registry({"urn:example": foo}, retrieve=blow_up) + assert registry["urn:example"] == foo + assert registry.get_or_retrieve("urn:example").value == foo + + def test_retrieve_first_checks_crawlable_resource(self): + child = ID_AND_CHILDREN.create_resource({"ID": "urn:child", "foo": 
12}) + root = ID_AND_CHILDREN.create_resource({"children": [child.contents]}) + registry = Registry(retrieve=blow_up).with_resource("urn:root", root) + assert registry.crawl()["urn:child"] == child + + def test_resolver(self): + one = Resource.opaque(contents={}) + registry = Registry({"http://example.com": one}) + resolver = registry.resolver(base_uri="http://example.com") + assert resolver.lookup("#").contents == {} + + def test_resolver_with_root_identified(self): + root = ID_AND_CHILDREN.create_resource({"ID": "http://example.com"}) + resolver = Registry().resolver_with_root(root) + assert resolver.lookup("http://example.com").contents == root.contents + assert resolver.lookup("#").contents == root.contents + + def test_resolver_with_root_unidentified(self): + root = Resource.opaque(contents={}) + resolver = Registry().resolver_with_root(root) + assert resolver.lookup("#").contents == root.contents + + def test_repr(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + registry = Registry().with_resources( + [ + ("http://example.com/1", one), + ("http://example.com/foo/bar", two), + ], + ) + assert repr(registry) == "" + assert repr(registry.crawl()) == "" + + def test_repr_mixed_crawled(self): + one = Resource.opaque(contents={}) + two = ID_AND_CHILDREN.create_resource({"foo": "bar"}) + registry = ( + Registry( + {"http://example.com/1": one}, + ) + .crawl() + .with_resource(uri="http://example.com/foo/bar", resource=two) + ) + assert repr(registry) == "" + + def test_repr_one_resource(self): + registry = Registry().with_resource( + uri="http://example.com/1", + resource=Resource.opaque(contents={}), + ) + assert repr(registry) == "" + + def test_repr_empty(self): + assert repr(Registry()) == "" + + +class TestResource: + def test_from_contents_from_json_schema(self): + schema = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + resource = Resource.from_contents(schema) + assert resource == 
Resource(contents=schema, specification=DRAFT202012) + + def test_from_contents_with_no_discernible_information(self): + """ + Creating a resource with no discernible way to see what + specification it belongs to (e.g. no ``$schema`` keyword for JSON + Schema) raises an error. + """ + + with pytest.raises(exceptions.CannotDetermineSpecification): + Resource.from_contents({"foo": "bar"}) + + def test_from_contents_with_no_discernible_information_and_default(self): + resource = Resource.from_contents( + {"foo": "bar"}, + default_specification=Specification.OPAQUE, + ) + assert resource == Resource.opaque(contents={"foo": "bar"}) + + def test_from_contents_unneeded_default(self): + schema = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + resource = Resource.from_contents( + schema, + default_specification=Specification.OPAQUE, + ) + assert resource == Resource( + contents=schema, + specification=DRAFT202012, + ) + + def test_non_mapping_from_contents(self): + resource = Resource.from_contents( + True, + default_specification=ID_AND_CHILDREN, + ) + assert resource == Resource( + contents=True, + specification=ID_AND_CHILDREN, + ) + + def test_from_contents_with_fallback(self): + resource = Resource.from_contents( + {"foo": "bar"}, + default_specification=Specification.OPAQUE, + ) + assert resource == Resource.opaque(contents={"foo": "bar"}) + + def test_id_delegates_to_specification(self): + specification = Specification( + name="", + id_of=lambda contents: "urn:fixedID", + subresources_of=lambda contents: [], + anchors_in=lambda specification, contents: [], + maybe_in_subresource=( + lambda segments, resolver, subresource: resolver + ), + ) + resource = Resource( + contents={"foo": "baz"}, + specification=specification, + ) + assert resource.id() == "urn:fixedID" + + def test_id_strips_empty_fragment(self): + uri = "http://example.com/" + root = ID_AND_CHILDREN.create_resource({"ID": uri + "#"}) + assert root.id() == uri + + def 
test_subresources_delegates_to_specification(self): + resource = ID_AND_CHILDREN.create_resource({"children": [{}, 12]}) + assert list(resource.subresources()) == [ + ID_AND_CHILDREN.create_resource(each) for each in [{}, 12] + ] + + def test_subresource_with_different_specification(self): + schema = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + resource = ID_AND_CHILDREN.create_resource({"children": [schema]}) + assert list(resource.subresources()) == [ + DRAFT202012.create_resource(schema), + ] + + def test_anchors_delegates_to_specification(self): + resource = ID_AND_CHILDREN.create_resource( + {"anchors": {"foo": {}, "bar": 1, "baz": ""}}, + ) + assert list(resource.anchors()) == [ + Anchor(name="foo", resource=ID_AND_CHILDREN.create_resource({})), + Anchor(name="bar", resource=ID_AND_CHILDREN.create_resource(1)), + Anchor(name="baz", resource=ID_AND_CHILDREN.create_resource("")), + ] + + def test_pointer_to_mapping(self): + resource = Resource.opaque(contents={"foo": "baz"}) + resolver = Registry().resolver() + assert resource.pointer("/foo", resolver=resolver).contents == "baz" + + def test_pointer_to_array(self): + resource = Resource.opaque(contents={"foo": {"bar": [3]}}) + resolver = Registry().resolver() + assert resource.pointer("/foo/bar/0", resolver=resolver).contents == 3 + + def test_root_pointer(self): + contents = {"foo": "baz"} + resource = Resource.opaque(contents=contents) + resolver = Registry().resolver() + assert resource.pointer("", resolver=resolver).contents == contents + + def test_opaque(self): + contents = {"foo": "bar"} + assert Resource.opaque(contents) == Resource( + contents=contents, + specification=Specification.OPAQUE, + ) + + +class TestResolver: + def test_lookup_exact_uri(self): + resource = Resource.opaque(contents={"foo": "baz"}) + resolver = Registry({"http://example.com/1": resource}).resolver() + resolved = resolver.lookup("http://example.com/1") + assert resolved.contents == resource.contents + + def 
test_lookup_subresource(self): + root = ID_AND_CHILDREN.create_resource( + { + "ID": "http://example.com/", + "children": [ + {"ID": "http://example.com/a", "foo": 12}, + ], + }, + ) + registry = root @ Registry() + resolved = registry.resolver().lookup("http://example.com/a") + assert resolved.contents == {"ID": "http://example.com/a", "foo": 12} + + def test_lookup_anchor_with_id(self): + root = ID_AND_CHILDREN.create_resource( + { + "ID": "http://example.com/", + "anchors": {"foo": 12}, + }, + ) + registry = root @ Registry() + resolved = registry.resolver().lookup("http://example.com/#foo") + assert resolved.contents == 12 + + def test_lookup_anchor_without_id(self): + root = ID_AND_CHILDREN.create_resource({"anchors": {"foo": 12}}) + resolver = Registry().with_resource("urn:example", root).resolver() + resolved = resolver.lookup("urn:example#foo") + assert resolved.contents == 12 + + def test_lookup_unknown_reference(self): + resolver = Registry().resolver() + ref = "http://example.com/does/not/exist" + with pytest.raises(exceptions.Unresolvable) as e: + resolver.lookup(ref) + assert e.value == exceptions.Unresolvable(ref=ref) + + def test_lookup_non_existent_pointer(self): + resource = Resource.opaque({"foo": {}}) + resolver = Registry({"http://example.com/1": resource}).resolver() + ref = "http://example.com/1#/foo/bar" + with pytest.raises(exceptions.Unresolvable) as e: + resolver.lookup(ref) + assert e.value == exceptions.PointerToNowhere( + ref="/foo/bar", + resource=resource, + ) + assert str(e.value) == "'/foo/bar' does not exist within {'foo': {}}" + + def test_lookup_non_existent_pointer_to_array_index(self): + resource = Resource.opaque([1, 2, 4, 8]) + resolver = Registry({"http://example.com/1": resource}).resolver() + ref = "http://example.com/1#/10" + with pytest.raises(exceptions.Unresolvable) as e: + resolver.lookup(ref) + assert e.value == exceptions.PointerToNowhere( + ref="/10", + resource=resource, + ) + + def 
test_lookup_pointer_to_empty_string(self): + resolver = Registry().resolver_with_root(Resource.opaque({"": {}})) + assert resolver.lookup("#/").contents == {} + + def test_lookup_non_existent_pointer_to_empty_string(self): + resource = Resource.opaque({"foo": {}}) + resolver = Registry().resolver_with_root(resource) + with pytest.raises( + exceptions.Unresolvable, + match="^'/' does not exist within {'foo': {}}.*'#'", + ) as e: + resolver.lookup("#/") + assert e.value == exceptions.PointerToNowhere( + ref="/", + resource=resource, + ) + + def test_lookup_non_existent_anchor(self): + root = ID_AND_CHILDREN.create_resource({"anchors": {}}) + resolver = Registry().with_resource("urn:example", root).resolver() + resolved = resolver.lookup("urn:example") + assert resolved.contents == root.contents + + ref = "urn:example#noSuchAnchor" + with pytest.raises(exceptions.Unresolvable) as e: + resolver.lookup(ref) + assert "'noSuchAnchor' does not exist" in str(e.value) + assert e.value == exceptions.NoSuchAnchor( + ref="urn:example", + resource=root, + anchor="noSuchAnchor", + ) + + def test_lookup_invalid_JSON_pointerish_anchor(self): + resolver = Registry().resolver_with_root( + ID_AND_CHILDREN.create_resource( + { + "ID": "http://example.com/", + "foo": {"bar": 12}, + }, + ), + ) + + valid = resolver.lookup("#/foo/bar") + assert valid.contents == 12 + + with pytest.raises(exceptions.InvalidAnchor) as e: + resolver.lookup("#foo/bar") + assert " '#/foo/bar'" in str(e.value) + + def test_lookup_retrieved_resource(self): + resource = Resource.opaque(contents={"foo": "baz"}) + resolver = Registry(retrieve=lambda uri: resource).resolver() + resolved = resolver.lookup("http://example.com/") + assert resolved.contents == resource.contents + + def test_lookup_failed_retrieved_resource(self): + """ + Unretrievable exceptions are also wrapped in Unresolvable. 
+ """ + + uri = "http://example.com/" + + registry = Registry(retrieve=blow_up) + with pytest.raises(exceptions.Unretrievable): + registry.get_or_retrieve(uri) + + resolver = registry.resolver() + with pytest.raises(exceptions.Unresolvable): + resolver.lookup(uri) + + def test_repeated_lookup_from_retrieved_resource(self): + """ + A (custom-)retrieved resource is added to the registry returned by + looking it up. + """ + resource = Resource.opaque(contents={"foo": "baz"}) + once = [resource] + + def retrieve(uri): + return once.pop() + + resolver = Registry(retrieve=retrieve).resolver() + resolved = resolver.lookup("http://example.com/") + assert resolved.contents == resource.contents + + resolved = resolved.resolver.lookup("http://example.com/") + assert resolved.contents == resource.contents + + def test_repeated_anchor_lookup_from_retrieved_resource(self): + resource = Resource.opaque(contents={"foo": "baz"}) + once = [resource] + + def retrieve(uri): + return once.pop() + + resolver = Registry(retrieve=retrieve).resolver() + resolved = resolver.lookup("http://example.com/") + assert resolved.contents == resource.contents + + resolved = resolved.resolver.lookup("#") + assert resolved.contents == resource.contents + + # FIXME: The tests below aren't really representable in the current + # suite, though we should probably think of ways to do so. 
+ + def test_in_subresource(self): + root = ID_AND_CHILDREN.create_resource( + { + "ID": "http://example.com/", + "children": [ + { + "ID": "child/", + "children": [{"ID": "grandchild"}], + }, + ], + }, + ) + registry = root @ Registry() + + resolver = registry.resolver() + first = resolver.lookup("http://example.com/") + assert first.contents == root.contents + + with pytest.raises(exceptions.Unresolvable): + first.resolver.lookup("grandchild") + + sub = first.resolver.in_subresource( + ID_AND_CHILDREN.create_resource(first.contents["children"][0]), + ) + second = sub.lookup("grandchild") + assert second.contents == {"ID": "grandchild"} + + def test_in_pointer_subresource(self): + root = ID_AND_CHILDREN.create_resource( + { + "ID": "http://example.com/", + "children": [ + { + "ID": "child/", + "children": [{"ID": "grandchild"}], + }, + ], + }, + ) + registry = root @ Registry() + + resolver = registry.resolver() + first = resolver.lookup("http://example.com/") + assert first.contents == root.contents + + with pytest.raises(exceptions.Unresolvable): + first.resolver.lookup("grandchild") + + second = first.resolver.lookup("#/children/0") + third = second.resolver.lookup("grandchild") + assert third.contents == {"ID": "grandchild"} + + def test_dynamic_scope(self): + one = ID_AND_CHILDREN.create_resource( + { + "ID": "http://example.com/", + "children": [ + { + "ID": "child/", + "children": [{"ID": "grandchild"}], + }, + ], + }, + ) + two = ID_AND_CHILDREN.create_resource( + { + "ID": "http://example.com/two", + "children": [{"ID": "two-child/"}], + }, + ) + registry = [one, two] @ Registry() + + resolver = registry.resolver() + first = resolver.lookup("http://example.com/") + second = first.resolver.lookup("#/children/0") + third = second.resolver.lookup("grandchild") + fourth = third.resolver.lookup("http://example.com/two") + assert list(fourth.resolver.dynamic_scope()) == [ + ("http://example.com/child/grandchild", fourth.resolver._registry), + 
("http://example.com/child/", fourth.resolver._registry), + ("http://example.com/", fourth.resolver._registry), + ] + assert list(third.resolver.dynamic_scope()) == [ + ("http://example.com/child/", third.resolver._registry), + ("http://example.com/", third.resolver._registry), + ] + assert list(second.resolver.dynamic_scope()) == [ + ("http://example.com/", second.resolver._registry), + ] + assert list(first.resolver.dynamic_scope()) == [] + + +class TestSpecification: + def test_create_resource(self): + specification = Specification( + name="", + id_of=lambda contents: "urn:fixedID", + subresources_of=lambda contents: [], + anchors_in=lambda specification, contents: [], + maybe_in_subresource=( + lambda segments, resolver, subresource: resolver + ), + ) + resource = specification.create_resource(contents={"foo": "baz"}) + assert resource == Resource( + contents={"foo": "baz"}, + specification=specification, + ) + assert resource.id() == "urn:fixedID" + + def test_detect_from_json_schema(self): + schema = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + specification = Specification.detect(schema) + assert specification == DRAFT202012 + + def test_detect_with_no_discernible_information(self): + with pytest.raises(exceptions.CannotDetermineSpecification): + Specification.detect({"foo": "bar"}) + + def test_detect_with_non_URI_schema(self): + with pytest.raises(exceptions.CannotDetermineSpecification): + Specification.detect({"$schema": 37}) + + def test_detect_with_no_discernible_information_and_default(self): + specification = Specification.OPAQUE.detect({"foo": "bar"}) + assert specification is Specification.OPAQUE + + def test_detect_unneeded_default(self): + schema = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + specification = Specification.OPAQUE.detect(schema) + assert specification == DRAFT202012 + + def test_non_mapping_detect(self): + with pytest.raises(exceptions.CannotDetermineSpecification): + Specification.detect(True) + + 
def test_non_mapping_detect_with_default(self): + specification = ID_AND_CHILDREN.detect(True) + assert specification is ID_AND_CHILDREN + + def test_detect_with_fallback(self): + specification = Specification.OPAQUE.detect({"foo": "bar"}) + assert specification is Specification.OPAQUE + + def test_repr(self): + assert ( + repr(ID_AND_CHILDREN) == "" + ) + + +class TestOpaqueSpecification: + THINGS = [{"foo": "bar"}, True, 37, "foo", object()] + + @pytest.mark.parametrize("thing", THINGS) + def test_no_id(self, thing): + """ + An arbitrary thing has no ID. + """ + + assert Specification.OPAQUE.id_of(thing) is None + + @pytest.mark.parametrize("thing", THINGS) + def test_no_subresources(self, thing): + """ + An arbitrary thing has no subresources. + """ + + assert list(Specification.OPAQUE.subresources_of(thing)) == [] + + @pytest.mark.parametrize("thing", THINGS) + def test_no_anchors(self, thing): + """ + An arbitrary thing has no anchors. + """ + + assert list(Specification.OPAQUE.anchors_in(thing)) == [] + + +@pytest.mark.parametrize( + "cls", + [Anchor, Registry, Resource, Specification, exceptions.PointerToNowhere], +) +def test_nonsubclassable(cls): + with pytest.raises(Exception, match="(?i)subclassing"): + + class Boom(cls): # pragma: no cover + pass diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/test_exceptions.py b/.venv/lib/python3.11/site-packages/referencing/tests/test_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..85cf99ecdd86c86e84df0b64f24aec6c447f4c08 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/tests/test_exceptions.py @@ -0,0 +1,34 @@ +import itertools + +import pytest + +from referencing import Resource, exceptions + + +def pairs(choices): + return itertools.combinations(choices, 2) + + +TRUE = Resource.opaque(True) + + +thunks = ( + lambda: exceptions.CannotDetermineSpecification(TRUE), + lambda: exceptions.NoSuchResource("urn:example:foo"), + lambda: 
exceptions.NoInternalID(TRUE), + lambda: exceptions.InvalidAnchor(resource=TRUE, anchor="foo", ref="a#b"), + lambda: exceptions.NoSuchAnchor(resource=TRUE, anchor="foo", ref="a#b"), + lambda: exceptions.PointerToNowhere(resource=TRUE, ref="urn:example:foo"), + lambda: exceptions.Unresolvable("urn:example:foo"), + lambda: exceptions.Unretrievable("urn:example:foo"), +) + + +@pytest.mark.parametrize("one, two", pairs(each() for each in thunks)) +def test_eq_incompatible_types(one, two): + assert one != two + + +@pytest.mark.parametrize("thunk", thunks) +def test_hash(thunk): + assert thunk() in {thunk()} diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/test_jsonschema.py b/.venv/lib/python3.11/site-packages/referencing/tests/test_jsonschema.py new file mode 100644 index 0000000000000000000000000000000000000000..c80714d0132bebbec33401f42a2e06aee3fed9c6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/tests/test_jsonschema.py @@ -0,0 +1,382 @@ +import pytest + +from referencing import Registry, Resource, Specification +import referencing.jsonschema + + +@pytest.mark.parametrize( + "uri, expected", + [ + ( + "https://json-schema.org/draft/2020-12/schema", + referencing.jsonschema.DRAFT202012, + ), + ( + "https://json-schema.org/draft/2019-09/schema", + referencing.jsonschema.DRAFT201909, + ), + ( + "http://json-schema.org/draft-07/schema#", + referencing.jsonschema.DRAFT7, + ), + ( + "http://json-schema.org/draft-06/schema#", + referencing.jsonschema.DRAFT6, + ), + ( + "http://json-schema.org/draft-04/schema#", + referencing.jsonschema.DRAFT4, + ), + ( + "http://json-schema.org/draft-03/schema#", + referencing.jsonschema.DRAFT3, + ), + ], +) +def test_schemas_with_explicit_schema_keywords_are_detected(uri, expected): + """ + The $schema keyword in JSON Schema is a dialect identifier. 
+ """ + contents = {"$schema": uri} + resource = Resource.from_contents(contents) + assert resource == Resource(contents=contents, specification=expected) + + +def test_unknown_dialect(): + dialect_id = "http://example.com/unknown-json-schema-dialect-id" + with pytest.raises(referencing.jsonschema.UnknownDialect) as excinfo: + Resource.from_contents({"$schema": dialect_id}) + assert excinfo.value.uri == dialect_id + + +@pytest.mark.parametrize( + "id, specification", + [ + ("$id", referencing.jsonschema.DRAFT202012), + ("$id", referencing.jsonschema.DRAFT201909), + ("$id", referencing.jsonschema.DRAFT7), + ("$id", referencing.jsonschema.DRAFT6), + ("id", referencing.jsonschema.DRAFT4), + ("id", referencing.jsonschema.DRAFT3), + ], +) +def test_id_of_mapping(id, specification): + uri = "http://example.com/some-schema" + assert specification.id_of({id: uri}) == uri + + +@pytest.mark.parametrize( + "specification", + [ + referencing.jsonschema.DRAFT202012, + referencing.jsonschema.DRAFT201909, + referencing.jsonschema.DRAFT7, + referencing.jsonschema.DRAFT6, + ], +) +@pytest.mark.parametrize("value", [True, False]) +def test_id_of_bool(specification, value): + assert specification.id_of(value) is None + + +@pytest.mark.parametrize( + "specification", + [ + referencing.jsonschema.DRAFT202012, + referencing.jsonschema.DRAFT201909, + referencing.jsonschema.DRAFT7, + referencing.jsonschema.DRAFT6, + ], +) +@pytest.mark.parametrize("value", [True, False]) +def test_anchors_in_bool(specification, value): + assert list(specification.anchors_in(value)) == [] + + +@pytest.mark.parametrize( + "specification", + [ + referencing.jsonschema.DRAFT202012, + referencing.jsonschema.DRAFT201909, + referencing.jsonschema.DRAFT7, + referencing.jsonschema.DRAFT6, + ], +) +@pytest.mark.parametrize("value", [True, False]) +def test_subresources_of_bool(specification, value): + assert list(specification.subresources_of(value)) == [] + + +@pytest.mark.parametrize( + "uri, expected", + [ + ( + 
"https://json-schema.org/draft/2020-12/schema", + referencing.jsonschema.DRAFT202012, + ), + ( + "https://json-schema.org/draft/2019-09/schema", + referencing.jsonschema.DRAFT201909, + ), + ( + "http://json-schema.org/draft-07/schema#", + referencing.jsonschema.DRAFT7, + ), + ( + "http://json-schema.org/draft-06/schema#", + referencing.jsonschema.DRAFT6, + ), + ( + "http://json-schema.org/draft-04/schema#", + referencing.jsonschema.DRAFT4, + ), + ( + "http://json-schema.org/draft-03/schema#", + referencing.jsonschema.DRAFT3, + ), + ], +) +def test_specification_with(uri, expected): + assert referencing.jsonschema.specification_with(uri) == expected + + +@pytest.mark.parametrize( + "uri, expected", + [ + ( + "http://json-schema.org/draft-07/schema", + referencing.jsonschema.DRAFT7, + ), + ( + "http://json-schema.org/draft-06/schema", + referencing.jsonschema.DRAFT6, + ), + ( + "http://json-schema.org/draft-04/schema", + referencing.jsonschema.DRAFT4, + ), + ( + "http://json-schema.org/draft-03/schema", + referencing.jsonschema.DRAFT3, + ), + ], +) +def test_specification_with_no_empty_fragment(uri, expected): + assert referencing.jsonschema.specification_with(uri) == expected + + +def test_specification_with_unknown_dialect(): + dialect_id = "http://example.com/unknown-json-schema-dialect-id" + with pytest.raises(referencing.jsonschema.UnknownDialect) as excinfo: + referencing.jsonschema.specification_with(dialect_id) + assert excinfo.value.uri == dialect_id + + +def test_specification_with_default(): + dialect_id = "http://example.com/unknown-json-schema-dialect-id" + specification = referencing.jsonschema.specification_with( + dialect_id, + default=Specification.OPAQUE, + ) + assert specification is Specification.OPAQUE + + +# FIXME: The tests below should move to the referencing suite but I haven't yet +# figured out how to represent dynamic (& recursive) ref lookups in it. 
+def test_lookup_trivial_dynamic_ref(): + one = referencing.jsonschema.DRAFT202012.create_resource( + {"$dynamicAnchor": "foo"}, + ) + resolver = Registry().with_resource("http://example.com", one).resolver() + resolved = resolver.lookup("http://example.com#foo") + assert resolved.contents == one.contents + + +def test_multiple_lookup_trivial_dynamic_ref(): + TRUE = referencing.jsonschema.DRAFT202012.create_resource(True) + root = referencing.jsonschema.DRAFT202012.create_resource( + { + "$id": "http://example.com", + "$dynamicAnchor": "fooAnchor", + "$defs": { + "foo": { + "$id": "foo", + "$dynamicAnchor": "fooAnchor", + "$defs": { + "bar": True, + "baz": { + "$dynamicAnchor": "fooAnchor", + }, + }, + }, + }, + }, + ) + resolver = ( + Registry() + .with_resources( + [ + ("http://example.com", root), + ("http://example.com/foo/", TRUE), + ("http://example.com/foo/bar", root), + ], + ) + .resolver() + ) + + first = resolver.lookup("http://example.com") + second = first.resolver.lookup("foo/") + resolver = second.resolver.lookup("bar").resolver + fourth = resolver.lookup("#fooAnchor") + assert fourth.contents == root.contents + + +def test_multiple_lookup_dynamic_ref_to_nondynamic_ref(): + one = referencing.jsonschema.DRAFT202012.create_resource( + {"$anchor": "fooAnchor"}, + ) + two = referencing.jsonschema.DRAFT202012.create_resource( + { + "$id": "http://example.com", + "$dynamicAnchor": "fooAnchor", + "$defs": { + "foo": { + "$id": "foo", + "$dynamicAnchor": "fooAnchor", + "$defs": { + "bar": True, + "baz": { + "$dynamicAnchor": "fooAnchor", + }, + }, + }, + }, + }, + ) + resolver = ( + Registry() + .with_resources( + [ + ("http://example.com", two), + ("http://example.com/foo/", one), + ("http://example.com/foo/bar", two), + ], + ) + .resolver() + ) + + first = resolver.lookup("http://example.com") + second = first.resolver.lookup("foo/") + resolver = second.resolver.lookup("bar").resolver + fourth = resolver.lookup("#fooAnchor") + assert fourth.contents == 
two.contents + + +def test_lookup_trivial_recursive_ref(): + one = referencing.jsonschema.DRAFT201909.create_resource( + {"$recursiveAnchor": True}, + ) + resolver = Registry().with_resource("http://example.com", one).resolver() + first = resolver.lookup("http://example.com") + resolved = referencing.jsonschema.lookup_recursive_ref( + resolver=first.resolver, + ) + assert resolved.contents == one.contents + + +def test_lookup_recursive_ref_to_bool(): + TRUE = referencing.jsonschema.DRAFT201909.create_resource(True) + registry = Registry({"http://example.com": TRUE}) + resolved = referencing.jsonschema.lookup_recursive_ref( + resolver=registry.resolver(base_uri="http://example.com"), + ) + assert resolved.contents == TRUE.contents + + +def test_multiple_lookup_recursive_ref_to_bool(): + TRUE = referencing.jsonschema.DRAFT201909.create_resource(True) + root = referencing.jsonschema.DRAFT201909.create_resource( + { + "$id": "http://example.com", + "$recursiveAnchor": True, + "$defs": { + "foo": { + "$id": "foo", + "$recursiveAnchor": True, + "$defs": { + "bar": True, + "baz": { + "$recursiveAnchor": True, + "$anchor": "fooAnchor", + }, + }, + }, + }, + }, + ) + resolver = ( + Registry() + .with_resources( + [ + ("http://example.com", root), + ("http://example.com/foo/", TRUE), + ("http://example.com/foo/bar", root), + ], + ) + .resolver() + ) + + first = resolver.lookup("http://example.com") + second = first.resolver.lookup("foo/") + resolver = second.resolver.lookup("bar").resolver + fourth = referencing.jsonschema.lookup_recursive_ref(resolver=resolver) + assert fourth.contents == root.contents + + +def test_multiple_lookup_recursive_ref_with_nonrecursive_ref(): + one = referencing.jsonschema.DRAFT201909.create_resource( + {"$recursiveAnchor": True}, + ) + two = referencing.jsonschema.DRAFT201909.create_resource( + { + "$id": "http://example.com", + "$recursiveAnchor": True, + "$defs": { + "foo": { + "$id": "foo", + "$recursiveAnchor": True, + "$defs": { + "bar": 
True, + "baz": { + "$recursiveAnchor": True, + "$anchor": "fooAnchor", + }, + }, + }, + }, + }, + ) + three = referencing.jsonschema.DRAFT201909.create_resource( + {"$recursiveAnchor": False}, + ) + resolver = ( + Registry() + .with_resources( + [ + ("http://example.com", three), + ("http://example.com/foo/", two), + ("http://example.com/foo/bar", one), + ], + ) + .resolver() + ) + + first = resolver.lookup("http://example.com") + second = first.resolver.lookup("foo/") + resolver = second.resolver.lookup("bar").resolver + fourth = referencing.jsonschema.lookup_recursive_ref(resolver=resolver) + assert fourth.contents == two.contents + + +def test_empty_registry(): + assert referencing.jsonschema.EMPTY_REGISTRY == Registry() diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/test_referencing_suite.py b/.venv/lib/python3.11/site-packages/referencing/tests/test_referencing_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8ae9177c197456bb3bbd62d4c1875bc95ff28b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/tests/test_referencing_suite.py @@ -0,0 +1,66 @@ +from pathlib import Path +import json +import os + +import pytest + +from referencing import Registry +from referencing.exceptions import Unresolvable +import referencing.jsonschema + + +class SuiteNotFound(Exception): + def __str__(self): # pragma: no cover + return ( + "Cannot find the referencing suite. " + "Set the REFERENCING_SUITE environment variable to the path to " + "the suite, or run the test suite from alongside a full checkout " + "of the git repository." 
+ ) + + +if "REFERENCING_SUITE" in os.environ: # pragma: no cover + SUITE = Path(os.environ["REFERENCING_SUITE"]) / "tests" +else: + SUITE = Path(__file__).parent.parent.parent / "suite/tests" +if not SUITE.is_dir(): # pragma: no cover + raise SuiteNotFound() +DIALECT_IDS = json.loads(SUITE.joinpath("specifications.json").read_text()) + + +@pytest.mark.parametrize( + "test_path", + [ + pytest.param(each, id=f"{each.parent.name}-{each.stem}") + for each in SUITE.glob("*/**/*.json") + ], +) +def test_referencing_suite(test_path, subtests): + dialect_id = DIALECT_IDS[test_path.relative_to(SUITE).parts[0]] + specification = referencing.jsonschema.specification_with(dialect_id) + loaded = json.loads(test_path.read_text()) + registry = loaded["registry"] + registry = Registry().with_resources( + (uri, specification.create_resource(contents)) + for uri, contents in loaded["registry"].items() + ) + for test in loaded["tests"]: + with subtests.test(test=test): + if "normalization" in test_path.stem: + pytest.xfail("APIs need to change for proper URL support.") + + resolver = registry.resolver(base_uri=test.get("base_uri", "")) + + if test.get("error"): + with pytest.raises(Unresolvable): + resolver.lookup(test["ref"]) + else: + resolved = resolver.lookup(test["ref"]) + assert resolved.contents == test["target"] + + then = test.get("then") + while then: # pragma: no cover + with subtests.test(test=test, then=then): + resolved = resolved.resolver.lookup(then["ref"]) + assert resolved.contents == then["target"] + then = then.get("then") diff --git a/.venv/lib/python3.11/site-packages/referencing/tests/test_retrieval.py b/.venv/lib/python3.11/site-packages/referencing/tests/test_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a8f8ad9975d1a760bca14dea7b60d41fb8ea75 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/tests/test_retrieval.py @@ -0,0 +1,106 @@ +from functools import lru_cache +import json + +import pytest + +from 
referencing import Registry, Resource, exceptions +from referencing.jsonschema import DRAFT202012 +from referencing.retrieval import to_cached_resource + + +class TestToCachedResource: + def test_it_caches_retrieved_resources(self): + contents = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + stack = [json.dumps(contents)] + + @to_cached_resource() + def retrieve(uri): + return stack.pop() + + registry = Registry(retrieve=retrieve) + + expected = Resource.from_contents(contents) + + got = registry.get_or_retrieve("urn:example:schema") + assert got.value == expected + + # And a second time we get the same value. + again = registry.get_or_retrieve("urn:example:schema") + assert again.value is got.value + + def test_custom_loader(self): + contents = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + stack = [json.dumps(contents)[::-1]] + + @to_cached_resource(loads=lambda s: json.loads(s[::-1])) + def retrieve(uri): + return stack.pop() + + registry = Registry(retrieve=retrieve) + + expected = Resource.from_contents(contents) + + got = registry.get_or_retrieve("urn:example:schema") + assert got.value == expected + + # And a second time we get the same value. + again = registry.get_or_retrieve("urn:example:schema") + assert again.value is got.value + + def test_custom_from_contents(self): + contents = {} + stack = [json.dumps(contents)] + + @to_cached_resource(from_contents=DRAFT202012.create_resource) + def retrieve(uri): + return stack.pop() + + registry = Registry(retrieve=retrieve) + + expected = DRAFT202012.create_resource(contents) + + got = registry.get_or_retrieve("urn:example:schema") + assert got.value == expected + + # And a second time we get the same value. 
+ again = registry.get_or_retrieve("urn:example:schema") + assert again.value is got.value + + def test_custom_cache(self): + schema = {"$schema": "https://json-schema.org/draft/2020-12/schema"} + mapping = { + "urn:example:1": dict(schema, foo=1), + "urn:example:2": dict(schema, foo=2), + "urn:example:3": dict(schema, foo=3), + } + + resources = { + uri: Resource.from_contents(contents) + for uri, contents in mapping.items() + } + + @to_cached_resource(cache=lru_cache(maxsize=2)) + def retrieve(uri): + return json.dumps(mapping.pop(uri)) + + registry = Registry(retrieve=retrieve) + + got = registry.get_or_retrieve("urn:example:1") + assert got.value == resources["urn:example:1"] + assert registry.get_or_retrieve("urn:example:1").value is got.value + assert registry.get_or_retrieve("urn:example:1").value is got.value + + got = registry.get_or_retrieve("urn:example:2") + assert got.value == resources["urn:example:2"] + assert registry.get_or_retrieve("urn:example:2").value is got.value + assert registry.get_or_retrieve("urn:example:2").value is got.value + + # This still succeeds, but evicts the first URI + got = registry.get_or_retrieve("urn:example:3") + assert got.value == resources["urn:example:3"] + assert registry.get_or_retrieve("urn:example:3").value is got.value + assert registry.get_or_retrieve("urn:example:3").value is got.value + + # And now this fails (as we popped the value out of `mapping`) + with pytest.raises(exceptions.Unretrievable): + registry.get_or_retrieve("urn:example:1") diff --git a/.venv/lib/python3.11/site-packages/referencing/typing.py b/.venv/lib/python3.11/site-packages/referencing/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..a61446417e68e8e397346b14e8525cc9062f1c59 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/referencing/typing.py @@ -0,0 +1,61 @@ +""" +Type-annotation related support for the referencing library. 
+""" + +from __future__ import annotations + +from collections.abc import Mapping as Mapping +from typing import TYPE_CHECKING, Any, Protocol + +try: + from typing_extensions import TypeVar +except ImportError: # pragma: no cover + from typing import TypeVar + +if TYPE_CHECKING: + from referencing._core import Resolved, Resolver, Resource + +#: A URI which identifies a `Resource`. +URI = str + +#: The type of documents within a registry. +D = TypeVar("D", default=Any) + + +class Retrieve(Protocol[D]): + """ + A retrieval callable, usable within a `Registry` for resource retrieval. + + Does not make assumptions about where the resource might be coming from. + """ + + def __call__(self, uri: URI) -> Resource[D]: + """ + Retrieve the resource with the given URI. + + Raise `referencing.exceptions.NoSuchResource` if you wish to indicate + the retriever cannot lookup the given URI. + """ + ... + + +class Anchor(Protocol[D]): + """ + An anchor within a `Resource`. + + Beyond "simple" anchors, some specifications like JSON Schema's 2020 + version have dynamic anchors. + """ + + @property + def name(self) -> str: + """ + Return the name of this anchor. + """ + ... + + def resolve(self, resolver: Resolver[D]) -> Resolved[D]: + """ + Return the resource for this anchor. + """ + ... 
diff --git a/.venv/lib/python3.11/site-packages/starlette/_exception_handler.py b/.venv/lib/python3.11/site-packages/starlette/_exception_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..72bc89d91a4e7c2400511e280c26dd1cb2fc6502 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/_exception_handler.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send +from starlette.websockets import WebSocket + +ExceptionHandlers = dict[typing.Any, ExceptionHandler] +StatusHandlers = dict[int, ExceptionHandler] + + +def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None: + for cls in type(exc).__mro__: + if cls in exc_handlers: + return exc_handlers[cls] + return None + + +def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp: + exception_handlers: ExceptionHandlers + status_handlers: StatusHandlers + try: + exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await app(scope, receive, sender) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exception_handlers, exc) + + if handler is None: + raise exc + + if response_started: + raise 
RuntimeError("Caught handled exception, but response already started.") from exc + + if is_async_callable(handler): + response = await handler(conn, exc) + else: + response = await run_in_threadpool(handler, conn, exc) # type: ignore + if response is not None: + await response(scope, receive, sender) + + return wrapped_app diff --git a/.venv/lib/python3.11/site-packages/starlette/applications.py b/.venv/lib/python3.11/site-packages/starlette/applications.py new file mode 100644 index 0000000000000000000000000000000000000000..0a717bb3af031f8b448f7b37df84e661236d89bd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/applications.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import sys +import typing +import warnings + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette.datastructures import State, URLPath +from starlette.middleware import Middleware, _MiddlewareFactory +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.errors import ServerErrorMiddleware +from starlette.middleware.exceptions import ExceptionMiddleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import BaseRoute, Router +from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send +from starlette.websockets import WebSocket + +AppType = typing.TypeVar("AppType", bound="Starlette") +P = ParamSpec("P") + + +class Starlette: + """Creates an Starlette application.""" + + def __init__( + self: AppType, + debug: bool = False, + routes: typing.Sequence[BaseRoute] | None = None, + middleware: typing.Sequence[Middleware] | None = None, + exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None, + on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + on_shutdown: typing.Sequence[typing.Callable[[], 
typing.Any]] | None = None, + lifespan: Lifespan[AppType] | None = None, + ) -> None: + """Initializes the application. + + Parameters: + debug: Boolean indicating if debug tracebacks should be returned on errors. + routes: A list of routes to serve incoming HTTP and WebSocket requests. + middleware: A list of middleware to run for every request. A starlette + application will always automatically include two middleware classes. + `ServerErrorMiddleware` is added as the very outermost middleware, to handle + any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal + with handled exception cases occurring in the routing or endpoints. + exception_handlers: A mapping of either integer status codes, + or exception class types onto callables which handle the exceptions. + Exception handler callables should be of the form + `handler(request, exc) -> response` and may be either standard functions, or + async functions. + on_startup: A list of callables to run on application startup. + Startup handler callables do not take any arguments, and may be either + standard functions, or async functions. + on_shutdown: A list of callables to run on application shutdown. + Shutdown handler callables do not take any arguments, and may be either + standard functions, or async functions. + lifespan: A lifespan context function, which can be used to perform + startup and shutdown tasks. This is a newer style that replaces the + `on_startup` and `on_shutdown` handlers. Use one or the other, not both. + """ + # The lifespan context function is a newer style that replaces + # on_startup / on_shutdown handlers. Use one or the other, not both. + assert lifespan is None or ( + on_startup is None and on_shutdown is None + ), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both." 
+ + self.debug = debug + self.state = State() + self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan) + self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers) + self.user_middleware = [] if middleware is None else list(middleware) + self.middleware_stack: ASGIApp | None = None + + def build_middleware_stack(self) -> ASGIApp: + debug = self.debug + error_handler = None + exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {} + + for key, value in self.exception_handlers.items(): + if key in (500, Exception): + error_handler = value + else: + exception_handlers[key] = value + + middleware = ( + [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + self.user_middleware + + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)] + ) + + app = self.router + for cls, args, kwargs in reversed(middleware): + app = cls(app, *args, **kwargs) + return app + + @property + def routes(self) -> list[BaseRoute]: + return self.router.routes + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + return self.router.url_path_for(name, **path_params) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + scope["app"] = self + if self.middleware_stack is None: + self.middleware_stack = self.build_middleware_stack() + await self.middleware_stack(scope, receive, send) + + def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] + return self.router.on_event(event_type) # pragma: no cover + + def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: + self.router.mount(path, app=app, name=name) # pragma: no cover + + def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: + self.router.host(host, app=app, name=name) # pragma: no cover + + def add_middleware( + self, + middleware_class: _MiddlewareFactory[P], + *args: P.args, 
+ **kwargs: P.kwargs, + ) -> None: + if self.middleware_stack is not None: # pragma: no cover + raise RuntimeError("Cannot add middleware after an application has started") + self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs)) + + def add_exception_handler( + self, + exc_class_or_status_code: int | type[Exception], + handler: ExceptionHandler, + ) -> None: # pragma: no cover + self.exception_handlers[exc_class_or_status_code] = handler + + def add_event_handler( + self, + event_type: str, + func: typing.Callable, # type: ignore[type-arg] + ) -> None: # pragma: no cover + self.router.add_event_handler(event_type, func) + + def add_route( + self, + path: str, + route: typing.Callable[[Request], typing.Awaitable[Response] | Response], + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> None: # pragma: no cover + self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema) + + def add_websocket_route( + self, + path: str, + route: typing.Callable[[WebSocket], typing.Awaitable[None]], + name: str | None = None, + ) -> None: # pragma: no cover + self.router.add_websocket_route(path, route, name=name) + + def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg] + warnings.warn( + "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. 
" + "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_exception_handler(exc_class_or_status_code, func) + return func + + return decorator + + def route( + self, + path: str, + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [Route(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `route` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/routing/ for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.router.add_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + return func + + return decorator + + def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [WebSocketRoute(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. 
" + "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.router.add_websocket_route(path, func, name=name) + return func + + return decorator + + def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> middleware = [Middleware(...), ...] + >>> app = Starlette(middleware=middleware) + """ + warnings.warn( + "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", + DeprecationWarning, + ) + assert middleware_type == "http", 'Currently only middleware("http") is supported.' + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_middleware(BaseHTTPMiddleware, dispatch=func) + return func + + return decorator diff --git a/.venv/lib/python3.11/site-packages/starlette/concurrency.py b/.venv/lib/python3.11/site-packages/starlette/concurrency.py new file mode 100644 index 0000000000000000000000000000000000000000..494f34204fc004f2be964bd227cb0c7a303d8351 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/concurrency.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import functools +import sys +import typing +import warnings + +import anyio.to_thread + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = typing.TypeVar("T") + + +async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] + warnings.warn( + "run_until_first_complete is deprecated and will be removed in a future version.", + 
DeprecationWarning, + ) + + async with anyio.create_task_group() as task_group: + + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] + await func() + task_group.cancel_scope.cancel() + + for func, kwargs in args: + task_group.start_soon(run, functools.partial(func, **kwargs)) + + +async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func = functools.partial(func, *args, **kwargs) + return await anyio.to_thread.run_sync(func) + + +class _StopIteration(Exception): + pass + + +def _next(iterator: typing.Iterator[T]) -> T: + # We can't raise `StopIteration` from within the threadpool iterator + # and catch it outside that context, so we coerce them into a different + # exception type. + try: + return next(iterator) + except StopIteration: + raise _StopIteration + + +async def iterate_in_threadpool( + iterator: typing.Iterable[T], +) -> typing.AsyncIterator[T]: + as_iterator = iter(iterator) + while True: + try: + yield await anyio.to_thread.run_sync(_next, as_iterator) + except _StopIteration: + break diff --git a/.venv/lib/python3.11/site-packages/starlette/config.py b/.venv/lib/python3.11/site-packages/starlette/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ca15c564670271e66790000a308a7d7e981e0bac --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/config.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import os +import typing +import warnings +from pathlib import Path + + +class undefined: + pass + + +class EnvironError(Exception): + pass + + +class Environ(typing.MutableMapping[str, str]): + def __init__(self, environ: typing.MutableMapping[str, str] = os.environ): + self._environ = environ + self._has_been_read: set[str] = set() + + def __getitem__(self, key: str) -> str: + self._has_been_read.add(key) + return self._environ.__getitem__(key) + + def __setitem__(self, key: str, value: str) -> None: + if key in 
self._has_been_read: + raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.") + self._environ.__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + if key in self._has_been_read: + raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.") + self._environ.__delitem__(key) + + def __iter__(self) -> typing.Iterator[str]: + return iter(self._environ) + + def __len__(self) -> int: + return len(self._environ) + + +environ = Environ() + +T = typing.TypeVar("T") + + +class Config: + def __init__( + self, + env_file: str | Path | None = None, + environ: typing.Mapping[str, str] = environ, + env_prefix: str = "", + ) -> None: + self.environ = environ + self.env_prefix = env_prefix + self.file_values: dict[str, str] = {} + if env_file is not None: + if not os.path.isfile(env_file): + warnings.warn(f"Config file '{env_file}' not found.") + else: + self.file_values = self._read_file(env_file) + + @typing.overload + def __call__(self, key: str, *, default: None) -> str | None: ... + + @typing.overload + def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ... + + @typing.overload + def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ... + + @typing.overload + def __call__( + self, + key: str, + cast: typing.Callable[[typing.Any], T] = ..., + default: typing.Any = ..., + ) -> T: ... + + @typing.overload + def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ... 
+ + def __call__( + self, + key: str, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, + default: typing.Any = undefined, + ) -> typing.Any: + return self.get(key, cast, default) + + def get( + self, + key: str, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, + default: typing.Any = undefined, + ) -> typing.Any: + key = self.env_prefix + key + if key in self.environ: + value = self.environ[key] + return self._perform_cast(key, value, cast) + if key in self.file_values: + value = self.file_values[key] + return self._perform_cast(key, value, cast) + if default is not undefined: + return self._perform_cast(key, default, cast) + raise KeyError(f"Config '{key}' is missing, and has no default.") + + def _read_file(self, file_name: str | Path) -> dict[str, str]: + file_values: dict[str, str] = {} + with open(file_name) as input_file: + for line in input_file.readlines(): + line = line.strip() + if "=" in line and not line.startswith("#"): + key, value = line.split("=", 1) + key = key.strip() + value = value.strip().strip("\"'") + file_values[key] = value + return file_values + + def _perform_cast( + self, + key: str, + value: typing.Any, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, + ) -> typing.Any: + if cast is None or value is None: + return value + elif cast is bool and isinstance(value, str): + mapping = {"true": True, "1": True, "false": False, "0": False} + value = value.lower() + if value not in mapping: + raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.") + return mapping[value] + try: + return cast(value) + except (TypeError, ValueError): + raise ValueError(f"Config '{key}' has value '{value}'. 
Not a valid {cast.__name__}.") diff --git a/.venv/lib/python3.11/site-packages/starlette/convertors.py b/.venv/lib/python3.11/site-packages/starlette/convertors.py new file mode 100644 index 0000000000000000000000000000000000000000..84df87a586adf44bb6135d547049b9f5385f30b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/convertors.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import math +import typing +import uuid + +T = typing.TypeVar("T") + + +class Convertor(typing.Generic[T]): + regex: typing.ClassVar[str] = "" + + def convert(self, value: str) -> T: + raise NotImplementedError() # pragma: no cover + + def to_string(self, value: T) -> str: + raise NotImplementedError() # pragma: no cover + + +class StringConvertor(Convertor[str]): + regex = "[^/]+" + + def convert(self, value: str) -> str: + return value + + def to_string(self, value: str) -> str: + value = str(value) + assert "/" not in value, "May not contain path separators" + assert value, "Must not be empty" + return value + + +class PathConvertor(Convertor[str]): + regex = ".*" + + def convert(self, value: str) -> str: + return str(value) + + def to_string(self, value: str) -> str: + return str(value) + + +class IntegerConvertor(Convertor[int]): + regex = "[0-9]+" + + def convert(self, value: str) -> int: + return int(value) + + def to_string(self, value: int) -> str: + value = int(value) + assert value >= 0, "Negative integers are not supported" + return str(value) + + +class FloatConvertor(Convertor[float]): + regex = r"[0-9]+(\.[0-9]+)?" 
+ + def convert(self, value: str) -> float: + return float(value) + + def to_string(self, value: float) -> str: + value = float(value) + assert value >= 0.0, "Negative floats are not supported" + assert not math.isnan(value), "NaN values are not supported" + assert not math.isinf(value), "Infinite values are not supported" + return ("%0.20f" % value).rstrip("0").rstrip(".") + + +class UUIDConvertor(Convertor[uuid.UUID]): + regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}" + + def convert(self, value: str) -> uuid.UUID: + return uuid.UUID(value) + + def to_string(self, value: uuid.UUID) -> str: + return str(value) + + +CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = { + "str": StringConvertor(), + "path": PathConvertor(), + "int": IntegerConvertor(), + "float": FloatConvertor(), + "uuid": UUIDConvertor(), +} + + +def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None: + CONVERTOR_TYPES[key] = convertor diff --git a/.venv/lib/python3.11/site-packages/starlette/datastructures.py b/.venv/lib/python3.11/site-packages/starlette/datastructures.py new file mode 100644 index 0000000000000000000000000000000000000000..90a7296a09a63a6993d652988782d991b2154686 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/datastructures.py @@ -0,0 +1,679 @@ +from __future__ import annotations + +import typing +from shlex import shlex +from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit + +from starlette.concurrency import run_in_threadpool +from starlette.types import Scope + + +class Address(typing.NamedTuple): + host: str + port: int + + +_KeyType = typing.TypeVar("_KeyType") +# Mapping keys are invariant but their values are covariant since +# you can only read them +# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` +_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) + + +class URL: + def __init__( + self, + url: str = "", + scope: Scope | None = 
None, + **components: typing.Any, + ) -> None: + if scope is not None: + assert not url, 'Cannot set both "url" and "scope".' + assert not components, 'Cannot set both "scope" and "**components".' + scheme = scope.get("scheme", "http") + server = scope.get("server", None) + path = scope["path"] + query_string = scope.get("query_string", b"") + + host_header = None + for key, value in scope["headers"]: + if key == b"host": + host_header = value.decode("latin-1") + break + + if host_header is not None: + url = f"{scheme}://{host_header}{path}" + elif server is None: + url = path + else: + host, port = server + default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] + if port == default_port: + url = f"{scheme}://{host}{path}" + else: + url = f"{scheme}://{host}:{port}{path}" + + if query_string: + url += "?" + query_string.decode() + elif components: + assert not url, 'Cannot set both "url" and "**components".' + url = URL("").replace(**components).components.geturl() + + self._url = url + + @property + def components(self) -> SplitResult: + if not hasattr(self, "_components"): + self._components = urlsplit(self._url) + return self._components + + @property + def scheme(self) -> str: + return self.components.scheme + + @property + def netloc(self) -> str: + return self.components.netloc + + @property + def path(self) -> str: + return self.components.path + + @property + def query(self) -> str: + return self.components.query + + @property + def fragment(self) -> str: + return self.components.fragment + + @property + def username(self) -> None | str: + return self.components.username + + @property + def password(self) -> None | str: + return self.components.password + + @property + def hostname(self) -> None | str: + return self.components.hostname + + @property + def port(self) -> int | None: + return self.components.port + + @property + def is_secure(self) -> bool: + return self.scheme in ("https", "wss") + + def replace(self, **kwargs: typing.Any) -> 
URL: + if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs: + hostname = kwargs.pop("hostname", None) + port = kwargs.pop("port", self.port) + username = kwargs.pop("username", self.username) + password = kwargs.pop("password", self.password) + + if hostname is None: + netloc = self.netloc + _, _, hostname = netloc.rpartition("@") + + if hostname[-1] != "]": + hostname = hostname.rsplit(":", 1)[0] + + netloc = hostname + if port is not None: + netloc += f":{port}" + if username is not None: + userpass = username + if password is not None: + userpass += f":{password}" + netloc = f"{userpass}@{netloc}" + + kwargs["netloc"] = netloc + + components = self.components._replace(**kwargs) + return self.__class__(components.geturl()) + + def include_query_params(self, **kwargs: typing.Any) -> URL: + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + params.update({str(key): str(value) for key, value in kwargs.items()}) + query = urlencode(params.multi_items()) + return self.replace(query=query) + + def replace_query_params(self, **kwargs: typing.Any) -> URL: + query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) + return self.replace(query=query) + + def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL: + if isinstance(keys, str): + keys = [keys] + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + for key in keys: + params.pop(key, None) + query = urlencode(params.multi_items()) + return self.replace(query=query) + + def __eq__(self, other: typing.Any) -> bool: + return str(self) == str(other) + + def __str__(self) -> str: + return self._url + + def __repr__(self) -> str: + url = str(self) + if self.password: + url = str(self.replace(password="********")) + return f"{self.__class__.__name__}({repr(url)})" + + +class URLPath(str): + """ + A URL path string that may also hold an associated protocol and/or host. 
+ Used by the routing to return `url_path_for` matches. + """ + + def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath: + assert protocol in ("http", "websocket", "") + return str.__new__(cls, path) + + def __init__(self, path: str, protocol: str = "", host: str = "") -> None: + self.protocol = protocol + self.host = host + + def make_absolute_url(self, base_url: str | URL) -> URL: + if isinstance(base_url, str): + base_url = URL(base_url) + if self.protocol: + scheme = { + "http": {True: "https", False: "http"}, + "websocket": {True: "wss", False: "ws"}, + }[self.protocol][base_url.is_secure] + else: + scheme = base_url.scheme + + netloc = self.host or base_url.netloc + path = base_url.path.rstrip("/") + str(self) + return URL(scheme=scheme, netloc=netloc, path=path) + + +class Secret: + """ + Holds a string value that should not be revealed in tracebacks etc. + You should cast the value to `str` at the point it is required. + """ + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}('**********')" + + def __str__(self) -> str: + return self._value + + def __bool__(self) -> bool: + return bool(self._value) + + +class CommaSeparatedStrings(typing.Sequence[str]): + def __init__(self, value: str | typing.Sequence[str]): + if isinstance(value, str): + splitter = shlex(value, posix=True) + splitter.whitespace = "," + splitter.whitespace_split = True + self._items = [item.strip() for item in splitter] + else: + self._items = list(value) + + def __len__(self) -> int: + return len(self._items) + + def __getitem__(self, index: int | slice) -> typing.Any: + return self._items[index] + + def __iter__(self) -> typing.Iterator[str]: + return iter(self._items) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + items = [item for item in self] + return f"{class_name}({items!r})" + + def __str__(self) -> str: + return ", ".join(repr(item) for 
item in self) + + +class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): + _dict: dict[_KeyType, _CovariantValueType] + + def __init__( + self, + *args: ImmutableMultiDict[_KeyType, _CovariantValueType] + | typing.Mapping[_KeyType, _CovariantValueType] + | typing.Iterable[tuple[_KeyType, _CovariantValueType]], + **kwargs: typing.Any, + ) -> None: + assert len(args) < 2, "Too many arguments." + + value: typing.Any = args[0] if args else [] + if kwargs: + value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items() + + if not value: + _items: list[tuple[typing.Any, typing.Any]] = [] + elif hasattr(value, "multi_items"): + value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value) + _items = list(value.multi_items()) + elif hasattr(value, "items"): + value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) + _items = list(value.items()) + else: + value = typing.cast("list[tuple[typing.Any, typing.Any]]", value) + _items = list(value) + + self._dict = {k: v for k, v in _items} + self._list = _items + + def getlist(self, key: typing.Any) -> list[_CovariantValueType]: + return [item_value for item_key, item_value in self._list if item_key == key] + + def keys(self) -> typing.KeysView[_KeyType]: + return self._dict.keys() + + def values(self) -> typing.ValuesView[_CovariantValueType]: + return self._dict.values() + + def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]: + return self._dict.items() + + def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]: + return list(self._list) + + def __getitem__(self, key: _KeyType) -> _CovariantValueType: + return self._dict[key] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[_KeyType]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, 
self.__class__): + return False + return sorted(self._list) == sorted(other._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + items = self.multi_items() + return f"{class_name}({items!r})" + + +class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): + def __setitem__(self, key: typing.Any, value: typing.Any) -> None: + self.setlist(key, [value]) + + def __delitem__(self, key: typing.Any) -> None: + self._list = [(k, v) for k, v in self._list if k != key] + del self._dict[key] + + def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + self._list = [(k, v) for k, v in self._list if k != key] + return self._dict.pop(key, default) + + def popitem(self) -> tuple[typing.Any, typing.Any]: + key, value = self._dict.popitem() + self._list = [(k, v) for k, v in self._list if k != key] + return key, value + + def poplist(self, key: typing.Any) -> list[typing.Any]: + values = [v for k, v in self._list if k == key] + self.pop(key) + return values + + def clear(self) -> None: + self._dict.clear() + self._list.clear() + + def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + if key not in self: + self._dict[key] = default + self._list.append((key, default)) + + return self[key] + + def setlist(self, key: typing.Any, values: list[typing.Any]) -> None: + if not values: + self.pop(key, None) + else: + existing_items = [(k, v) for (k, v) in self._list if k != key] + self._list = existing_items + [(key, value) for value in values] + self._dict[key] = values[-1] + + def append(self, key: typing.Any, value: typing.Any) -> None: + self._list.append((key, value)) + self._dict[key] = value + + def update( + self, + *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]], + **kwargs: typing.Any, + ) -> None: + value = MultiDict(*args, **kwargs) + existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()] + self._list = existing_items + 
value.multi_items() + self._dict.update(value) + + +class QueryParams(ImmutableMultiDict[str, str]): + """ + An immutable multidict. + """ + + def __init__( + self, + *args: ImmutableMultiDict[typing.Any, typing.Any] + | typing.Mapping[typing.Any, typing.Any] + | list[tuple[typing.Any, typing.Any]] + | str + | bytes, + **kwargs: typing.Any, + ) -> None: + assert len(args) < 2, "Too many arguments." + + value = args[0] if args else [] + + if isinstance(value, str): + super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs) + elif isinstance(value, bytes): + super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs) + else: + super().__init__(*args, **kwargs) # type: ignore[arg-type] + self._list = [(str(k), str(v)) for k, v in self._list] + self._dict = {str(k): str(v) for k, v in self._dict.items()} + + def __str__(self) -> str: + return urlencode(self._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + query_string = str(self) + return f"{class_name}({query_string!r})" + + +class UploadFile: + """ + An uploaded file included as part of the request data. 
+ """ + + def __init__( + self, + file: typing.BinaryIO, + *, + size: int | None = None, + filename: str | None = None, + headers: Headers | None = None, + ) -> None: + self.filename = filename + self.file = file + self.size = size + self.headers = headers or Headers() + + @property + def content_type(self) -> str | None: + return self.headers.get("content-type", None) + + @property + def _in_memory(self) -> bool: + # check for SpooledTemporaryFile._rolled + rolled_to_disk = getattr(self.file, "_rolled", True) + return not rolled_to_disk + + async def write(self, data: bytes) -> None: + if self.size is not None: + self.size += len(data) + + if self._in_memory: + self.file.write(data) + else: + await run_in_threadpool(self.file.write, data) + + async def read(self, size: int = -1) -> bytes: + if self._in_memory: + return self.file.read(size) + return await run_in_threadpool(self.file.read, size) + + async def seek(self, offset: int) -> None: + if self._in_memory: + self.file.seek(offset) + else: + await run_in_threadpool(self.file.seek, offset) + + async def close(self) -> None: + if self._in_memory: + self.file.close() + else: + await run_in_threadpool(self.file.close) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"filename={self.filename!r}, " + f"size={self.size!r}, " + f"headers={self.headers!r})" + ) + + +class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): + """ + An immutable multidict, containing both file uploads and text input. + """ + + def __init__( + self, + *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]], + **kwargs: str | UploadFile, + ) -> None: + super().__init__(*args, **kwargs) + + async def close(self) -> None: + for key, value in self.multi_items(): + if isinstance(value, UploadFile): + await value.close() + + +class Headers(typing.Mapping[str, str]): + """ + An immutable, case-insensitive multidict. 
+ """ + + def __init__( + self, + headers: typing.Mapping[str, str] | None = None, + raw: list[tuple[bytes, bytes]] | None = None, + scope: typing.MutableMapping[str, typing.Any] | None = None, + ) -> None: + self._list: list[tuple[bytes, bytes]] = [] + if headers is not None: + assert raw is None, 'Cannot set both "headers" and "raw".' + assert scope is None, 'Cannot set both "headers" and "scope".' + self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()] + elif raw is not None: + assert scope is None, 'Cannot set both "raw" and "scope".' + self._list = raw + elif scope is not None: + # scope["headers"] isn't necessarily a list + # it might be a tuple or other iterable + self._list = scope["headers"] = list(scope["headers"]) + + @property + def raw(self) -> list[tuple[bytes, bytes]]: + return list(self._list) + + def keys(self) -> list[str]: # type: ignore[override] + return [key.decode("latin-1") for key, value in self._list] + + def values(self) -> list[str]: # type: ignore[override] + return [value.decode("latin-1") for key, value in self._list] + + def items(self) -> list[tuple[str, str]]: # type: ignore[override] + return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list] + + def getlist(self, key: str) -> list[str]: + get_header_key = key.lower().encode("latin-1") + return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key] + + def mutablecopy(self) -> MutableHeaders: + return MutableHeaders(raw=self._list[:]) + + def __getitem__(self, key: str) -> str: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return header_value.decode("latin-1") + raise KeyError(key) + + def __contains__(self, key: typing.Any) -> bool: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return 
True + return False + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, Headers): + return False + return sorted(self._list) == sorted(other._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + as_dict = dict(self.items()) + if len(as_dict) == len(self): + return f"{class_name}({as_dict!r})" + return f"{class_name}(raw={self.raw!r})" + + +class MutableHeaders(Headers): + def __setitem__(self, key: str, value: str) -> None: + """ + Set the header `key` to `value`, removing any duplicate entries. + Retains insertion order. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + found_indexes: list[int] = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + found_indexes.append(idx) + + for idx in reversed(found_indexes[1:]): + del self._list[idx] + + if found_indexes: + idx = found_indexes[0] + self._list[idx] = (set_key, set_value) + else: + self._list.append((set_key, set_value)) + + def __delitem__(self, key: str) -> None: + """ + Remove the header `key`. 
+ """ + del_key = key.lower().encode("latin-1") + + pop_indexes: list[int] = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == del_key: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del self._list[idx] + + def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + self.update(other) + return self + + def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + new = self.mutablecopy() + new.update(other) + return new + + @property + def raw(self) -> list[tuple[bytes, bytes]]: + return self._list + + def setdefault(self, key: str, value: str) -> str: + """ + If the header `key` does not exist, then set it to `value`. + Returns the header value. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + return item_value.decode("latin-1") + self._list.append((set_key, set_value)) + return value + + def update(self, other: typing.Mapping[str, str]) -> None: + for key, val in other.items(): + self[key] = val + + def append(self, key: str, value: str) -> None: + """ + Append a header, preserving any duplicate entries. + """ + append_key = key.lower().encode("latin-1") + append_value = value.encode("latin-1") + self._list.append((append_key, append_value)) + + def add_vary_header(self, vary: str) -> None: + existing = self.get("vary") + if existing is not None: + vary = ", ".join([existing, vary]) + self["vary"] = vary + + +class State: + """ + An object that can be used to store arbitrary state. + + Used for `request.state` and `app.state`. 
+ """ + + _state: dict[str, typing.Any] + + def __init__(self, state: dict[str, typing.Any] | None = None): + if state is None: + state = {} + super().__setattr__("_state", state) + + def __setattr__(self, key: typing.Any, value: typing.Any) -> None: + self._state[key] = value + + def __getattr__(self, key: typing.Any) -> typing.Any: + try: + return self._state[key] + except KeyError: + message = "'{}' object has no attribute '{}'" + raise AttributeError(message.format(self.__class__.__name__, key)) + + def __delattr__(self, key: typing.Any) -> None: + del self._state[key] diff --git a/.venv/lib/python3.11/site-packages/starlette/exceptions.py b/.venv/lib/python3.11/site-packages/starlette/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad3527b94ce034bb3967b3bf0969403118d4284 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/exceptions.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import http +from collections.abc import Mapping + + +class HTTPException(Exception): + def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None: + if detail is None: + detail = http.HTTPStatus(status_code).phrase + self.status_code = status_code + self.detail = detail + self.headers = headers + + def __str__(self) -> str: + return f"{self.status_code}: {self.detail}" + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" + + +class WebSocketException(Exception): + def __init__(self, code: int, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + def __str__(self) -> str: + return f"{self.code}: {self.reason}" + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(code={self.code!r}, reason={self.reason!r})" diff --git a/.venv/lib/python3.11/site-packages/starlette/py.typed 
b/.venv/lib/python3.11/site-packages/starlette/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/starlette/requests.py b/.venv/lib/python3.11/site-packages/starlette/requests.py new file mode 100644 index 0000000000000000000000000000000000000000..369f632e56d61c54abbb8b91fab34042032420ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/requests.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import json +import typing +from http import cookies as http_cookies + +import anyio + +from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper +from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State +from starlette.exceptions import HTTPException +from starlette.formparsers import FormParser, MultiPartException, MultiPartParser +from starlette.types import Message, Receive, Scope, Send + +if typing.TYPE_CHECKING: + from python_multipart.multipart import parse_options_header + + from starlette.applications import Starlette + from starlette.routing import Router +else: + try: + try: + from python_multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + from multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + parse_options_header = None + + +SERVER_PUSH_HEADERS_TO_COPY = { + "accept", + "accept-encoding", + "accept-language", + "cache-control", + "user-agent", +} + + +def cookie_parser(cookie_string: str) -> dict[str, str]: + """ + This function parses a ``Cookie`` HTTP header into a dict of key/value pairs. + + It attempts to mimic browser cookie parsing behavior: browsers and web servers + frequently disregard the spec (RFC 6265) when setting and reading cookies, + so we attempt to suit the common scenarios here. + + This function has been adapted from Django 3.1.0. 
+ Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based + on an outdated spec and will fail on lots of input we want to support + """ + cookie_dict: dict[str, str] = {} + for chunk in cookie_string.split(";"): + if "=" in chunk: + key, val = chunk.split("=", 1) + else: + # Assume an empty name per + # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 + key, val = "", chunk + key, val = key.strip(), val.strip() + if key or val: + # unquote using Python's algorithm. + cookie_dict[key] = http_cookies._unquote(val) + return cookie_dict + + +class ClientDisconnect(Exception): + pass + + +class HTTPConnection(typing.Mapping[str, typing.Any]): + """ + A base class for incoming HTTP connections, that is used to provide + any functionality that is common to both `Request` and `WebSocket`. + """ + + def __init__(self, scope: Scope, receive: Receive | None = None) -> None: + assert scope["type"] in ("http", "websocket") + self.scope = scope + + def __getitem__(self, key: str) -> typing.Any: + return self.scope[key] + + def __iter__(self) -> typing.Iterator[str]: + return iter(self.scope) + + def __len__(self) -> int: + return len(self.scope) + + # Don't use the `abc.Mapping.__eq__` implementation. + # Connection instances should never be considered equal + # unless `self is other`. + __eq__ = object.__eq__ + __hash__ = object.__hash__ + + @property + def app(self) -> typing.Any: + return self.scope["app"] + + @property + def url(self) -> URL: + if not hasattr(self, "_url"): # pragma: no branch + self._url = URL(scope=self.scope) + return self._url + + @property + def base_url(self) -> URL: + if not hasattr(self, "_base_url"): + base_url_scope = dict(self.scope) + # This is used by request.url_for, it might be used inside a Mount which + # would have its own child scope with its own root_path, but the base URL + # for url_for should still be the top level app root path. 
+ app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", "")) + path = app_root_path + if not path.endswith("/"): + path += "/" + base_url_scope["path"] = path + base_url_scope["query_string"] = b"" + base_url_scope["root_path"] = app_root_path + self._base_url = URL(scope=base_url_scope) + return self._base_url + + @property + def headers(self) -> Headers: + if not hasattr(self, "_headers"): + self._headers = Headers(scope=self.scope) + return self._headers + + @property + def query_params(self) -> QueryParams: + if not hasattr(self, "_query_params"): # pragma: no branch + self._query_params = QueryParams(self.scope["query_string"]) + return self._query_params + + @property + def path_params(self) -> dict[str, typing.Any]: + return self.scope.get("path_params", {}) + + @property + def cookies(self) -> dict[str, str]: + if not hasattr(self, "_cookies"): + cookies: dict[str, str] = {} + cookie_header = self.headers.get("cookie") + + if cookie_header: + cookies = cookie_parser(cookie_header) + self._cookies = cookies + return self._cookies + + @property + def client(self) -> Address | None: + # client is a 2 item tuple of (host, port), None if missing + host_port = self.scope.get("client") + if host_port is not None: + return Address(*host_port) + return None + + @property + def session(self) -> dict[str, typing.Any]: + assert "session" in self.scope, "SessionMiddleware must be installed to access request.session" + return self.scope["session"] # type: ignore[no-any-return] + + @property + def auth(self) -> typing.Any: + assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth" + return self.scope["auth"] + + @property + def user(self) -> typing.Any: + assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user" + return self.scope["user"] + + @property + def state(self) -> State: + if not hasattr(self, "_state"): + # Ensure 'state' has an empty dict if it's not 
already populated. + self.scope.setdefault("state", {}) + # Create a state instance with a reference to the dict in which it should + # store info + self._state = State(self.scope["state"]) + return self._state + + def url_for(self, name: str, /, **path_params: typing.Any) -> URL: + url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app") + if url_path_provider is None: + raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.") + url_path = url_path_provider.url_path_for(name, **path_params) + return url_path.make_absolute_url(base_url=self.base_url) + + +async def empty_receive() -> typing.NoReturn: + raise RuntimeError("Receive channel has not been made available") + + +async def empty_send(message: Message) -> typing.NoReturn: + raise RuntimeError("Send channel has not been made available") + + +class Request(HTTPConnection): + _form: FormData | None + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send): + super().__init__(scope) + assert scope["type"] == "http" + self._receive = receive + self._send = send + self._stream_consumed = False + self._is_disconnected = False + self._form = None + + @property + def method(self) -> str: + return typing.cast(str, self.scope["method"]) + + @property + def receive(self) -> Receive: + return self._receive + + async def stream(self) -> typing.AsyncGenerator[bytes, None]: + if hasattr(self, "_body"): + yield self._body + yield b"" + return + if self._stream_consumed: + raise RuntimeError("Stream consumed") + while not self._stream_consumed: + message = await self._receive() + if message["type"] == "http.request": + body = message.get("body", b"") + if not message.get("more_body", False): + self._stream_consumed = True + if body: + yield body + elif message["type"] == "http.disconnect": # pragma: no branch + self._is_disconnected = True + raise ClientDisconnect() + yield b"" + + async def 
body(self) -> bytes: + if not hasattr(self, "_body"): + chunks: list[bytes] = [] + async for chunk in self.stream(): + chunks.append(chunk) + self._body = b"".join(chunks) + return self._body + + async def json(self) -> typing.Any: + if not hasattr(self, "_json"): # pragma: no branch + body = await self.body() + self._json = json.loads(body) + return self._json + + async def _get_form( + self, + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_part_size: int = 1024 * 1024, + ) -> FormData: + if self._form is None: # pragma: no branch + assert ( + parse_options_header is not None + ), "The `python-multipart` library must be installed to use form parsing." + content_type_header = self.headers.get("Content-Type") + content_type: bytes + content_type, _ = parse_options_header(content_type_header) + if content_type == b"multipart/form-data": + try: + multipart_parser = MultiPartParser( + self.headers, + self.stream(), + max_files=max_files, + max_fields=max_fields, + max_part_size=max_part_size, + ) + self._form = await multipart_parser.parse() + except MultiPartException as exc: + if "app" in self.scope: + raise HTTPException(status_code=400, detail=exc.message) + raise exc + elif content_type == b"application/x-www-form-urlencoded": + form_parser = FormParser(self.headers, self.stream()) + self._form = await form_parser.parse() + else: + self._form = FormData() + return self._form + + def form( + self, + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_part_size: int = 1024 * 1024, + ) -> AwaitableOrContextManager[FormData]: + return AwaitableOrContextManagerWrapper( + self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size) + ) + + async def close(self) -> None: + if self._form is not None: # pragma: no branch + await self._form.close() + + async def is_disconnected(self) -> bool: + if not self._is_disconnected: + message: Message = {} + + # If message isn't immediately available, 
move on + with anyio.CancelScope() as cs: + cs.cancel() + message = await self._receive() + + if message.get("type") == "http.disconnect": + self._is_disconnected = True + + return self._is_disconnected + + async def send_push_promise(self, path: str) -> None: + if "http.response.push" in self.scope.get("extensions", {}): + raw_headers: list[tuple[bytes, bytes]] = [] + for name in SERVER_PUSH_HEADERS_TO_COPY: + for value in self.headers.getlist(name): + raw_headers.append((name.encode("latin-1"), value.encode("latin-1"))) + await self._send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/.venv/lib/python3.11/site-packages/starlette/responses.py b/.venv/lib/python3.11/site-packages/starlette/responses.py new file mode 100644 index 0000000000000000000000000000000000000000..31874f655b96be7e36ddb67faaf4c3bd223f5372 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/responses.py @@ -0,0 +1,537 @@ +from __future__ import annotations + +import hashlib +import http.cookies +import json +import os +import re +import stat +import typing +import warnings +from datetime import datetime +from email.utils import format_datetime, formatdate +from functools import partial +from mimetypes import guess_type +from secrets import token_hex +from urllib.parse import quote + +import anyio +import anyio.to_thread + +from starlette.background import BackgroundTask +from starlette.concurrency import iterate_in_threadpool +from starlette.datastructures import URL, Headers, MutableHeaders +from starlette.requests import ClientDisconnect +from starlette.types import Receive, Scope, Send + + +class Response: + media_type = None + charset = "utf-8" + + def __init__( + self, + content: typing.Any = None, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + self.status_code = status_code + if media_type is not None: + self.media_type 
= media_type + self.background = background + self.body = self.render(content) + self.init_headers(headers) + + def render(self, content: typing.Any) -> bytes | memoryview: + if content is None: + return b"" + if isinstance(content, (bytes, memoryview)): + return content + return content.encode(self.charset) # type: ignore + + def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None: + if headers is None: + raw_headers: list[tuple[bytes, bytes]] = [] + populate_content_length = True + populate_content_type = True + else: + raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()] + keys = [h[0] for h in raw_headers] + populate_content_length = b"content-length" not in keys + populate_content_type = b"content-type" not in keys + + body = getattr(self, "body", None) + if ( + body is not None + and populate_content_length + and not (self.status_code < 200 or self.status_code in (204, 304)) + ): + content_length = str(len(body)) + raw_headers.append((b"content-length", content_length.encode("latin-1"))) + + content_type = self.media_type + if content_type is not None and populate_content_type: + if content_type.startswith("text/") and "charset=" not in content_type.lower(): + content_type += "; charset=" + self.charset + raw_headers.append((b"content-type", content_type.encode("latin-1"))) + + self.raw_headers = raw_headers + + @property + def headers(self) -> MutableHeaders: + if not hasattr(self, "_headers"): + self._headers = MutableHeaders(raw=self.raw_headers) + return self._headers + + def set_cookie( + self, + key: str, + value: str = "", + max_age: int | None = None, + expires: datetime | str | int | None = None, + path: str | None = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: typing.Literal["lax", "strict", "none"] | None = "lax", + ) -> None: + cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie() + cookie[key] = value + if max_age is 
not None: + cookie[key]["max-age"] = max_age + if expires is not None: + if isinstance(expires, datetime): + cookie[key]["expires"] = format_datetime(expires, usegmt=True) + else: + cookie[key]["expires"] = expires + if path is not None: + cookie[key]["path"] = path + if domain is not None: + cookie[key]["domain"] = domain + if secure: + cookie[key]["secure"] = True + if httponly: + cookie[key]["httponly"] = True + if samesite is not None: + assert samesite.lower() in [ + "strict", + "lax", + "none", + ], "samesite must be either 'strict', 'lax' or 'none'" + cookie[key]["samesite"] = samesite + cookie_val = cookie.output(header="").strip() + self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) + + def delete_cookie( + self, + key: str, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: typing.Literal["lax", "strict", "none"] | None = "lax", + ) -> None: + self.set_cookie( + key, + max_age=0, + expires=0, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + samesite=samesite, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + prefix = "websocket." 
if scope["type"] == "websocket" else "" + await send( + { + "type": prefix + "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + await send({"type": prefix + "http.response.body", "body": self.body}) + + if self.background is not None: + await self.background() + + +class HTMLResponse(Response): + media_type = "text/html" + + +class PlainTextResponse(Response): + media_type = "text/plain" + + +class JSONResponse(Response): + media_type = "application/json" + + def __init__( + self, + content: typing.Any, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + super().__init__(content, status_code, headers, media_type, background) + + def render(self, content: typing.Any) -> bytes: + return json.dumps( + content, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ).encode("utf-8") + + +class RedirectResponse(Response): + def __init__( + self, + url: str | URL, + status_code: int = 307, + headers: typing.Mapping[str, str] | None = None, + background: BackgroundTask | None = None, + ) -> None: + super().__init__(content=b"", status_code=status_code, headers=headers, background=background) + self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") + + +Content = typing.Union[str, bytes, memoryview] +SyncContentStream = typing.Iterable[Content] +AsyncContentStream = typing.AsyncIterable[Content] +ContentStream = typing.Union[AsyncContentStream, SyncContentStream] + + +class StreamingResponse(Response): + body_iterator: AsyncContentStream + + def __init__( + self, + content: ContentStream, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + if isinstance(content, typing.AsyncIterable): + self.body_iterator = content + else: + self.body_iterator = 
iterate_in_threadpool(content) + self.status_code = status_code + self.media_type = self.media_type if media_type is None else media_type + self.background = background + self.init_headers(headers) + + async def listen_for_disconnect(self, receive: Receive) -> None: + while True: + message = await receive() + if message["type"] == "http.disconnect": + break + + async def stream_response(self, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + async for chunk in self.body_iterator: + if not isinstance(chunk, (bytes, memoryview)): + chunk = chunk.encode(self.charset) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) + + if spec_version >= (2, 4): + try: + await self.stream_response(send) + except OSError: + raise ClientDisconnect() + else: + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) + + if self.background is not None: + await self.background() + + +class MalformedRangeHeader(Exception): + def __init__(self, content: str = "Malformed range header.") -> None: + self.content = content + + +class RangeNotSatisfiable(Exception): + def __init__(self, max_size: int) -> None: + self.max_size = max_size + + +_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)") + + +class FileResponse(Response): + chunk_size = 64 * 1024 + + def __init__( + self, + path: str | os.PathLike[str], + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + 
media_type: str | None = None, + background: BackgroundTask | None = None, + filename: str | None = None, + stat_result: os.stat_result | None = None, + method: str | None = None, + content_disposition_type: str = "attachment", + ) -> None: + self.path = path + self.status_code = status_code + self.filename = filename + if method is not None: + warnings.warn( + "The 'method' parameter is not used, and it will be removed.", + DeprecationWarning, + ) + if media_type is None: + media_type = guess_type(filename or path)[0] or "text/plain" + self.media_type = media_type + self.background = background + self.init_headers(headers) + self.headers.setdefault("accept-ranges", "bytes") + if self.filename is not None: + content_disposition_filename = quote(self.filename) + if content_disposition_filename != self.filename: + content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" + else: + content_disposition = f'{content_disposition_type}; filename="{self.filename}"' + self.headers.setdefault("content-disposition", content_disposition) + self.stat_result = stat_result + if stat_result is not None: + self.set_stat_headers(stat_result) + + def set_stat_headers(self, stat_result: os.stat_result) -> None: + content_length = str(stat_result.st_size) + last_modified = formatdate(stat_result.st_mtime, usegmt=True) + etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size) + etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"' + + self.headers.setdefault("content-length", content_length) + self.headers.setdefault("last-modified", last_modified) + self.headers.setdefault("etag", etag) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + send_header_only: bool = scope["method"].upper() == "HEAD" + if self.stat_result is None: + try: + stat_result = await anyio.to_thread.run_sync(os.stat, self.path) + self.set_stat_headers(stat_result) + except FileNotFoundError: + raise 
RuntimeError(f"File at path {self.path} does not exist.") + else: + mode = stat_result.st_mode + if not stat.S_ISREG(mode): + raise RuntimeError(f"File at path {self.path} is not a file.") + else: + stat_result = self.stat_result + + headers = Headers(scope=scope) + http_range = headers.get("range") + http_if_range = headers.get("if-range") + + if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)): + await self._handle_simple(send, send_header_only) + else: + try: + ranges = self._parse_range_header(http_range, stat_result.st_size) + except MalformedRangeHeader as exc: + return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send) + except RangeNotSatisfiable as exc: + response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"}) + return await response(scope, receive, send) + + if len(ranges) == 1: + start, end = ranges[0] + await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only) + else: + await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only) + + if self.background is not None: + await self.background() + + async def _handle_simple(self, send: Send, send_header_only: bool) -> None: + await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) + if send_header_only: + await send({"type": "http.response.body", "body": b"", "more_body": False}) + else: + async with await anyio.open_file(self.path, mode="rb") as file: + more_body = True + while more_body: + chunk = await file.read(self.chunk_size) + more_body = len(chunk) == self.chunk_size + await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) + + async def _handle_single_range( + self, send: Send, start: int, end: int, file_size: int, send_header_only: bool + ) -> None: + self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}" + self.headers["content-length"] = str(end 
- start) + await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers}) + if send_header_only: + await send({"type": "http.response.body", "body": b"", "more_body": False}) + else: + async with await anyio.open_file(self.path, mode="rb") as file: + await file.seek(start) + more_body = True + while more_body: + chunk = await file.read(min(self.chunk_size, end - start)) + start += len(chunk) + more_body = len(chunk) == self.chunk_size and start < end + await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) + + async def _handle_multiple_ranges( + self, + send: Send, + ranges: list[tuple[int, int]], + file_size: int, + send_header_only: bool, + ) -> None: + # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes). + boundary = token_hex(13) + content_length, header_generator = self.generate_multipart( + ranges, boundary, file_size, self.headers["content-type"] + ) + self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}" + self.headers["content-length"] = str(content_length) + await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers}) + if send_header_only: + await send({"type": "http.response.body", "body": b"", "more_body": False}) + else: + async with await anyio.open_file(self.path, mode="rb") as file: + for start, end in ranges: + await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True}) + await file.seek(start) + while start < end: + chunk = await file.read(min(self.chunk_size, end - start)) + start += len(chunk) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + await send({"type": "http.response.body", "body": b"\n", "more_body": True}) + await send( + { + "type": "http.response.body", + "body": f"\n--{boundary}--\n".encode("latin-1"), + "more_body": False, + } + ) + + def _should_use_range(self, http_if_range: str) -> bool: + return http_if_range == 
self.headers["last-modified"] or http_if_range == self.headers["etag"] + + @staticmethod + def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]: + ranges: list[tuple[int, int]] = [] + try: + units, range_ = http_range.split("=", 1) + except ValueError: + raise MalformedRangeHeader() + + units = units.strip().lower() + + if units != "bytes": + raise MalformedRangeHeader("Only support bytes range") + + ranges = [ + ( + int(_[0]) if _[0] else file_size - int(_[1]), + int(_[1]) + 1 if _[0] and _[1] and int(_[1]) < file_size else file_size, + ) + for _ in _RANGE_PATTERN.findall(range_) + if _ != ("", "") + ] + + if len(ranges) == 0: + raise MalformedRangeHeader("Range header: range must be requested") + + if any(not (0 <= start < file_size) for start, _ in ranges): + raise RangeNotSatisfiable(file_size) + + if any(start > end for start, end in ranges): + raise MalformedRangeHeader("Range header: start must be less than end") + + if len(ranges) == 1: + return ranges + + # Merge ranges + result: list[tuple[int, int]] = [] + for start, end in ranges: + for p in range(len(result)): + p_start, p_end = result[p] + if start > p_end: + continue + elif end < p_start: + result.insert(p, (start, end)) # THIS IS NOT REACHED! + break + else: + result[p] = (min(start, p_start), max(end, p_end)) + break + else: + result.append((start, end)) + + return result + + def generate_multipart( + self, + ranges: typing.Sequence[tuple[int, int]], + boundary: str, + max_size: int, + content_type: str, + ) -> tuple[int, typing.Callable[[int, int], bytes]]: + r""" + Multipart response headers generator. 
+ + ``` + --{boundary}\n + Content-Type: {content_type}\n + Content-Range: bytes {start}-{end-1}/{max_size}\n + \n + ..........content...........\n + --{boundary}\n + Content-Type: {content_type}\n + Content-Range: bytes {start}-{end-1}/{max_size}\n + \n + ..........content...........\n + --{boundary}--\n + ``` + """ + boundary_len = len(boundary) + static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size)) + content_length = sum( + (len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers + + (end - start) # Content + for start, end in ranges + ) + ( + 5 + boundary_len # --boundary--\n + ) + return ( + content_length, + lambda start, end: ( + f"--{boundary}\n" + f"Content-Type: {content_type}\n" + f"Content-Range: bytes {start}-{end-1}/{max_size}\n" + "\n" + ).encode("latin-1"), + ) diff --git a/.venv/lib/python3.11/site-packages/starlette/schemas.py b/.venv/lib/python3.11/site-packages/starlette/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc40e2ae0efddee75d0328cbe83720b8d80d886 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/schemas.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import inspect +import re +import typing + +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import BaseRoute, Host, Mount, Route + +try: + import yaml +except ModuleNotFoundError: # pragma: no cover + yaml = None # type: ignore[assignment] + + +class OpenAPIResponse(Response): + media_type = "application/vnd.oai.openapi" + + def render(self, content: typing.Any) -> bytes: + assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse." + assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary." 
+ return yaml.dump(content, default_flow_style=False).encode("utf-8") + + +class EndpointInfo(typing.NamedTuple): + path: str + http_method: str + func: typing.Callable[..., typing.Any] + + +_remove_converter_pattern = re.compile(r":\w+}") + + +class BaseSchemaGenerator: + def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: + raise NotImplementedError() # pragma: no cover + + def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: + """ + Given the routes, yields the following information: + + - path + eg: /users/ + - http_method + one of 'get', 'post', 'put', 'patch', 'delete', 'options' + - func + method ready to extract the docstring + """ + endpoints_info: list[EndpointInfo] = [] + + for route in routes: + if isinstance(route, (Mount, Host)): + routes = route.routes or [] + if isinstance(route, Mount): + path = self._remove_converter(route.path) + else: + path = "" + sub_endpoints = [ + EndpointInfo( + path="".join((path, sub_endpoint.path)), + http_method=sub_endpoint.http_method, + func=sub_endpoint.func, + ) + for sub_endpoint in self.get_endpoints(routes) + ] + endpoints_info.extend(sub_endpoints) + + elif not isinstance(route, Route) or not route.include_in_schema: + continue + + elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): + path = self._remove_converter(route.path) + for method in route.methods or ["GET"]: + if method == "HEAD": + continue + endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint)) + else: + path = self._remove_converter(route.path) + for method in ["get", "post", "put", "patch", "delete", "options"]: + if not hasattr(route.endpoint, method): + continue + func = getattr(route.endpoint, method) + endpoints_info.append(EndpointInfo(path, method.lower(), func)) + + return endpoints_info + + def _remove_converter(self, path: str) -> str: + """ + Remove the converter from the path. 
+ For example, a route like this: + Route("/users/{id:int}", endpoint=get_user, methods=["GET"]) + Should be represented as `/users/{id}` in the OpenAPI schema. + """ + return _remove_converter_pattern.sub("}", path) + + def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]: + """ + Given a function, parse the docstring as YAML and return a dictionary of info. + """ + docstring = func_or_method.__doc__ + if not docstring: + return {} + + assert yaml is not None, "`pyyaml` must be installed to use parse_docstring." + + # We support having regular docstrings before the schema + # definition. Here we return just the schema part from + # the docstring. + docstring = docstring.split("---")[-1] + + parsed = yaml.safe_load(docstring) + + if not isinstance(parsed, dict): + # A regular docstring (not yaml formatted) can return + # a simple string here, which wouldn't follow the schema. + return {} + + return parsed + + def OpenAPIResponse(self, request: Request) -> Response: + routes = request.app.routes + schema = self.get_schema(routes=routes) + return OpenAPIResponse(schema) + + +class SchemaGenerator(BaseSchemaGenerator): + def __init__(self, base_schema: dict[str, typing.Any]) -> None: + self.base_schema = base_schema + + def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: + schema = dict(self.base_schema) + schema.setdefault("paths", {}) + endpoints_info = self.get_endpoints(routes) + + for endpoint in endpoints_info: + parsed = self.parse_docstring(endpoint.func) + + if not parsed: + continue + + if endpoint.path not in schema["paths"]: + schema["paths"][endpoint.path] = {} + + schema["paths"][endpoint.path][endpoint.http_method] = parsed + + return schema diff --git a/.venv/lib/python3.11/site-packages/starlette/staticfiles.py b/.venv/lib/python3.11/site-packages/starlette/staticfiles.py new file mode 100644 index 0000000000000000000000000000000000000000..34be04cdc7c486a66944e147b7c775f44402f599 --- 
/dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/staticfiles.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import errno +import importlib.util +import os +import stat +import typing +from email.utils import parsedate + +import anyio +import anyio.to_thread + +from starlette._utils import get_route_path +from starlette.datastructures import URL, Headers +from starlette.exceptions import HTTPException +from starlette.responses import FileResponse, RedirectResponse, Response +from starlette.types import Receive, Scope, Send + +PathLike = typing.Union[str, "os.PathLike[str]"] + + +class NotModifiedResponse(Response): + NOT_MODIFIED_HEADERS = ( + "cache-control", + "content-location", + "date", + "etag", + "expires", + "vary", + ) + + def __init__(self, headers: Headers): + super().__init__( + status_code=304, + headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS}, + ) + + +class StaticFiles: + def __init__( + self, + *, + directory: PathLike | None = None, + packages: list[str | tuple[str, str]] | None = None, + html: bool = False, + check_dir: bool = True, + follow_symlink: bool = False, + ) -> None: + self.directory = directory + self.packages = packages + self.all_directories = self.get_directories(directory, packages) + self.html = html + self.config_checked = False + self.follow_symlink = follow_symlink + if check_dir and directory is not None and not os.path.isdir(directory): + raise RuntimeError(f"Directory '{directory}' does not exist") + + def get_directories( + self, + directory: PathLike | None = None, + packages: list[str | tuple[str, str]] | None = None, + ) -> list[PathLike]: + """ + Given `directory` and `packages` arguments, return a list of all the + directories that should be used for serving static files from. 
+ """ + directories = [] + if directory is not None: + directories.append(directory) + + for package in packages or []: + if isinstance(package, tuple): + package, statics_dir = package + else: + statics_dir = "statics" + spec = importlib.util.find_spec(package) + assert spec is not None, f"Package {package!r} could not be found." + assert spec.origin is not None, f"Package {package!r} could not be found." + package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir)) + assert os.path.isdir( + package_directory + ), f"Directory '{statics_dir!r}' in package {package!r} could not be found." + directories.append(package_directory) + + return directories + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + The ASGI entry point. + """ + assert scope["type"] == "http" + + if not self.config_checked: + await self.check_config() + self.config_checked = True + + path = self.get_path(scope) + response = await self.get_response(path, scope) + await response(scope, receive, send) + + def get_path(self, scope: Scope) -> str: + """ + Given the ASGI scope, return the `path` string to serve up, + with OS specific path separators, and any '..', '.' components removed. + """ + route_path = get_route_path(scope) + return os.path.normpath(os.path.join(*route_path.split("/"))) + + async def get_response(self, path: str, scope: Scope) -> Response: + """ + Returns an HTTP response, given the incoming path, method and request headers. + """ + if scope["method"] not in ("GET", "HEAD"): + raise HTTPException(status_code=405) + + try: + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path) + except PermissionError: + raise HTTPException(status_code=401) + except OSError as exc: + # Filename is too long, so it can't be a valid static file. 
+ if exc.errno == errno.ENAMETOOLONG: + raise HTTPException(status_code=404) + + raise exc + + if stat_result and stat.S_ISREG(stat_result.st_mode): + # We have a static file to serve. + return self.file_response(full_path, stat_result, scope) + + elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html: + # We're in HTML mode, and have got a directory URL. + # Check if we have 'index.html' file to serve. + index_path = os.path.join(path, "index.html") + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path) + if stat_result is not None and stat.S_ISREG(stat_result.st_mode): + if not scope["path"].endswith("/"): + # Directory URLs should redirect to always end in "/". + url = URL(scope=scope) + url = url.replace(path=url.path + "/") + return RedirectResponse(url=url) + return self.file_response(full_path, stat_result, scope) + + if self.html: + # Check for '404.html' if we're in HTML mode. + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html") + if stat_result and stat.S_ISREG(stat_result.st_mode): + return FileResponse(full_path, stat_result=stat_result, status_code=404) + raise HTTPException(status_code=404) + + def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]: + for directory in self.all_directories: + joined_path = os.path.join(directory, path) + if self.follow_symlink: + full_path = os.path.abspath(joined_path) + else: + full_path = os.path.realpath(joined_path) + directory = os.path.realpath(directory) + if os.path.commonpath([full_path, directory]) != str(directory): + # Don't allow misbehaving clients to break out of the static files directory. 
+ continue + try: + return full_path, os.stat(full_path) + except (FileNotFoundError, NotADirectoryError): + continue + return "", None + + def file_response( + self, + full_path: PathLike, + stat_result: os.stat_result, + scope: Scope, + status_code: int = 200, + ) -> Response: + request_headers = Headers(scope=scope) + + response = FileResponse(full_path, status_code=status_code, stat_result=stat_result) + if self.is_not_modified(response.headers, request_headers): + return NotModifiedResponse(response.headers) + return response + + async def check_config(self) -> None: + """ + Perform a one-off configuration check that StaticFiles is actually + pointed at a directory, so that we can raise loud errors rather than + just returning 404 responses. + """ + if self.directory is None: + return + + try: + stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) + except FileNotFoundError: + raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.") + if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)): + raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.") + + def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool: + """ + Given the request and response headers, return `True` if an HTTP + "Not Modified" response could be returned instead. 
+ """ + try: + if_none_match = request_headers["if-none-match"] + etag = response_headers["etag"] + if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]: + return True + except KeyError: + pass + + try: + if_modified_since = parsedate(request_headers["if-modified-since"]) + last_modified = parsedate(response_headers["last-modified"]) + if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified: + return True + except KeyError: + pass + + return False diff --git a/.venv/lib/python3.11/site-packages/starlette/templating.py b/.venv/lib/python3.11/site-packages/starlette/templating.py new file mode 100644 index 0000000000000000000000000000000000000000..6b01aac9209fdcc893f4b24fd09d8a3f14149ec9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/templating.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import typing +import warnings +from os import PathLike + +from starlette.background import BackgroundTask +from starlette.datastructures import URL +from starlette.requests import Request +from starlette.responses import HTMLResponse +from starlette.types import Receive, Scope, Send + +try: + import jinja2 + + # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1 + # hence we try to get pass_context (most installs will be >=3.1) + # and fall back to contextfunction, + # adding a type ignore for mypy to let us access an attribute that may not exist + if hasattr(jinja2, "pass_context"): + pass_context = jinja2.pass_context + else: # pragma: no cover + pass_context = jinja2.contextfunction # type: ignore[attr-defined] +except ModuleNotFoundError: # pragma: no cover + jinja2 = None # type: ignore[assignment] + + +class _TemplateResponse(HTMLResponse): + def __init__( + self, + template: typing.Any, + context: dict[str, typing.Any], + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: 
BackgroundTask | None = None, + ): + self.template = template + self.context = context + content = template.render(context) + super().__init__(content, status_code, headers, media_type, background) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + request = self.context.get("request", {}) + extensions = request.get("extensions", {}) + if "http.response.debug" in extensions: # pragma: no branch + await send( + { + "type": "http.response.debug", + "info": { + "template": self.template, + "context": self.context, + }, + } + ) + await super().__call__(scope, receive, send) + + +class Jinja2Templates: + """ + templates = Jinja2Templates("templates") + + return templates.TemplateResponse("index.html", {"request": request}) + """ + + @typing.overload + def __init__( + self, + directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]], + *, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + **env_options: typing.Any, + ) -> None: ... + + @typing.overload + def __init__( + self, + *, + env: jinja2.Environment, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + ) -> None: ... + + def __init__( + self, + directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None, + *, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + env: jinja2.Environment | None = None, + **env_options: typing.Any, + ) -> None: + if env_options: + warnings.warn( + "Extra environment options are deprecated. 
Use a preconfigured jinja2.Environment instead.", + DeprecationWarning, + ) + assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" + assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed" + self.context_processors = context_processors or [] + if directory is not None: + self.env = self._create_env(directory, **env_options) + elif env is not None: # pragma: no branch + self.env = env + + self._setup_env_defaults(self.env) + + def _create_env( + self, + directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]], + **env_options: typing.Any, + ) -> jinja2.Environment: + loader = jinja2.FileSystemLoader(directory) + env_options.setdefault("loader", loader) + env_options.setdefault("autoescape", True) + + return jinja2.Environment(**env_options) + + def _setup_env_defaults(self, env: jinja2.Environment) -> None: + @pass_context + def url_for( + context: dict[str, typing.Any], + name: str, + /, + **path_params: typing.Any, + ) -> URL: + request: Request = context["request"] + return request.url_for(name, **path_params) + + env.globals.setdefault("url_for", url_for) + + def get_template(self, name: str) -> jinja2.Template: + return self.env.get_template(name) + + @typing.overload + def TemplateResponse( + self, + request: Request, + name: str, + context: dict[str, typing.Any] | None = None, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> _TemplateResponse: ... + + @typing.overload + def TemplateResponse( + self, + name: str, + context: dict[str, typing.Any] | None = None, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> _TemplateResponse: + # Deprecated usage + ... 
+ + def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse: + if args: + if isinstance(args[0], str): # the first argument is template name (old style) + warnings.warn( + "The `name` is not the first parameter anymore. " + "The first parameter should be the `Request` instance.\n" + 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', + DeprecationWarning, + ) + + name = args[0] + context = args[1] if len(args) > 1 else kwargs.get("context", {}) + status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200) + headers = args[2] if len(args) > 2 else kwargs.get("headers") + media_type = args[3] if len(args) > 3 else kwargs.get("media_type") + background = args[4] if len(args) > 4 else kwargs.get("background") + + if "request" not in context: + raise ValueError('context must include a "request" key') + request = context["request"] + else: # the first argument is a request instance (new style) + request = args[0] + name = args[1] if len(args) > 1 else kwargs["name"] + context = args[2] if len(args) > 2 else kwargs.get("context", {}) + status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200) + headers = args[4] if len(args) > 4 else kwargs.get("headers") + media_type = args[5] if len(args) > 5 else kwargs.get("media_type") + background = args[6] if len(args) > 6 else kwargs.get("background") + else: # all arguments are kwargs + if "request" not in kwargs: + warnings.warn( + "The `TemplateResponse` now requires the `request` argument.\n" + 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', + DeprecationWarning, + ) + if "request" not in kwargs.get("context", {}): + raise ValueError('context must include a "request" key') + + context = kwargs.get("context", {}) + request = kwargs.get("request", context.get("request")) + name = typing.cast(str, kwargs["name"]) + status_code = kwargs.get("status_code", 200) + headers = 
kwargs.get("headers") + media_type = kwargs.get("media_type") + background = kwargs.get("background") + + context.setdefault("request", request) + for context_processor in self.context_processors: + context.update(context_processor(request)) + + template = self.get_template(name) + return _TemplateResponse( + template, + context, + status_code=status_code, + headers=headers, + media_type=media_type, + background=background, + ) diff --git a/.venv/lib/python3.11/site-packages/starlette/testclient.py b/.venv/lib/python3.11/site-packages/starlette/testclient.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2ad6e3555b62433c27c48fd9760dabdad52801 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/testclient.py @@ -0,0 +1,724 @@ +from __future__ import annotations + +import contextlib +import inspect +import io +import json +import math +import sys +import typing +from concurrent.futures import Future +from types import GeneratorType +from urllib.parse import unquote, urljoin + +import anyio +import anyio.abc +import anyio.from_thread +from anyio.streams.stapled import StapledObjectStream + +from starlette._utils import is_async_callable +from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocketDisconnect + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import TypeGuard +else: # pragma: no cover + from typing_extensions import TypeGuard + +try: + import httpx +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + "The starlette.testclient module requires the httpx package to be installed.\n" + "You can install this with:\n" + " $ pip install httpx\n" + ) +_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]] + +ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] +ASGI2App = typing.Callable[[Scope], ASGIInstance] +ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + + 
+_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]] + + +def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: + if inspect.isclass(app): + return hasattr(app, "__await__") + return is_async_callable(app) + + +class _WrapASGI2: + """ + Provide an ASGI3 interface onto an ASGI2 app. + """ + + def __init__(self, app: ASGI2App) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + instance = self.app(scope) + await instance(receive, send) + + +class _AsyncBackend(typing.TypedDict): + backend: str + backend_options: dict[str, typing.Any] + + +class _Upgrade(Exception): + def __init__(self, session: WebSocketTestSession) -> None: + self.session = session + + +class WebSocketDenialResponse( # type: ignore[misc] + httpx.Response, + WebSocketDisconnect, +): + """ + A special case of `WebSocketDisconnect`, raised in the `TestClient` if the + `WebSocket` is closed before being accepted with a `send_denial_response()`. 
+ """ + + +class WebSocketTestSession: + def __init__( + self, + app: ASGI3App, + scope: Scope, + portal_factory: _PortalFactoryType, + ) -> None: + self.app = app + self.scope = scope + self.accepted_subprotocol = None + self.portal_factory = portal_factory + self.extra_headers = None + + def __enter__(self) -> WebSocketTestSession: + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(self.portal_factory()) + fut, cs = portal.start_task(self._run) + stack.callback(fut.result) + stack.callback(portal.call, cs.cancel) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.callback(self.close, 1000) + self.exit_stack = stack.pop_all() + return self + + def __exit__(self, *args: typing.Any) -> bool | None: + return self.exit_stack.__exit__(*args) + + async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: + """ + The sub-thread in which the websocket session runs. 
+ """ + send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + send_tx, send_rx = send + receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + receive_tx, receive_rx = receive + with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: + self._receive_tx = receive_tx + self._send_rx = send_rx + task_status.started(cs) + await self.app(self.scope, receive_rx.receive, send_tx.send) + + # wait for cs.cancel to be called before closing streams + await anyio.sleep_forever() + + def _raise_on_close(self, message: Message) -> None: + if message["type"] == "websocket.close": + raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) + elif message["type"] == "websocket.http.response.start": + status_code: int = message["status"] + headers: list[tuple[bytes, bytes]] = message["headers"] + body: list[bytes] = [] + while True: + message = self.receive() + assert message["type"] == "websocket.http.response.body" + body.append(message["body"]) + if not message.get("more_body", False): + break + raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) + + def send(self, message: Message) -> None: + self.portal.call(self._receive_tx.send, message) + + def send_text(self, data: str) -> None: + self.send({"type": "websocket.receive", "text": data}) + + def send_bytes(self, data: bytes) -> None: + self.send({"type": "websocket.receive", "bytes": data}) + + def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None: + text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) + if mode == "text": + self.send({"type": "websocket.receive", "text": text}) + else: + self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) + + def close(self, code: int = 1000, reason: str | None = None) -> None: + self.send({"type": "websocket.disconnect", "code": code, 
"reason": reason}) + + def receive(self) -> Message: + return self.portal.call(self._send_rx.receive) + + def receive_text(self) -> str: + message = self.receive() + self._raise_on_close(message) + return typing.cast(str, message["text"]) + + def receive_bytes(self) -> bytes: + message = self.receive() + self._raise_on_close(message) + return typing.cast(bytes, message["bytes"]) + + def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any: + message = self.receive() + self._raise_on_close(message) + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) + + +class _TestClientTransport(httpx.BaseTransport): + def __init__( + self, + app: ASGI3App, + portal_factory: _PortalFactoryType, + raise_server_exceptions: bool = True, + root_path: str = "", + *, + client: tuple[str, int], + app_state: dict[str, typing.Any], + ) -> None: + self.app = app + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + self.portal_factory = portal_factory + self.app_state = app_state + self.client = client + + def handle_request(self, request: httpx.Request) -> httpx.Response: + scheme = request.url.scheme + netloc = request.url.netloc.decode(encoding="ascii") + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + # Include the 'host' header. + if "host" in request.headers: + headers: list[tuple[bytes, bytes]] = [] + elif port == default_port: # pragma: no cover + headers = [(b"host", host.encode())] + else: # pragma: no cover + headers = [(b"host", (f"{host}:{port}").encode())] + + # Include other request headers. 
+ headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] + + scope: dict[str, typing.Any] + + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols: typing.Sequence[str] = [] + else: + subprotocols = [value.strip() for value in subprotocol.split(",")] + scope = { + "type": "websocket", + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "subprotocols": subprotocols, + "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, + } + session = WebSocketTestSession(self.app, scope, self.portal_factory) + raise _Upgrade(session) + + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "extensions": {"http.response.debug": {}}, + "state": self.app_state.copy(), + } + + request_complete = False + response_started = False + response_complete: anyio.Event + raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()} + template = None + context = None + + async def receive() -> Message: + nonlocal request_complete + + if request_complete: + if not response_complete.is_set(): + await response_complete.wait() + return {"type": "http.disconnect"} + + body = request.read() + if isinstance(body, str): + body_bytes: bytes = body.encode("utf-8") # pragma: no cover + elif body is None: + body_bytes = b"" # pragma: no cover + elif isinstance(body, GeneratorType): + try: # pragma: no cover + chunk = body.send(None) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + return {"type": "http.request", 
"body": chunk, "more_body": True} + except StopIteration: # pragma: no cover + request_complete = True + return {"type": "http.request", "body": b""} + else: + body_bytes = body + + request_complete = True + return {"type": "http.request", "body": body_bytes} + + async def send(message: Message) -> None: + nonlocal raw_kwargs, response_started, template, context + + if message["type"] == "http.response.start": + assert not response_started, 'Received multiple "http.response.start" messages.' + raw_kwargs["status_code"] = message["status"] + raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])] + response_started = True + elif message["type"] == "http.response.body": + assert response_started, 'Received "http.response.body" without "http.response.start".' + assert not response_complete.is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + raw_kwargs["stream"].write(body) + if not more_body: + raw_kwargs["stream"].seek(0) + response_complete.set() + elif message["type"] == "http.response.debug": + template = message["info"]["template"] + context = message["info"]["context"] + + try: + with self.portal_factory() as portal: + response_complete = portal.call(anyio.Event) + portal.call(self.app, scope, receive, send) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc + + if self.raise_server_exceptions: + assert response_started, "TestClient did not receive any response." 
+ elif not response_started: + raw_kwargs = { + "status_code": 500, + "headers": [], + "stream": io.BytesIO(), + } + + raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) + + response = httpx.Response(**raw_kwargs, request=request) + if template is not None: + response.template = template # type: ignore[attr-defined] + response.context = context # type: ignore[attr-defined] + return response + + +class TestClient(httpx.Client): + __test__ = False + task: Future[None] + portal: anyio.abc.BlockingPortal | None = None + + def __init__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: typing.Literal["asyncio", "trio"] = "asyncio", + backend_options: dict[str, typing.Any] | None = None, + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + client: tuple[str, int] = ("testclient", 50000), + ) -> None: + self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {}) + if _is_asgi3(app): + asgi_app = app + else: + app = typing.cast(ASGI2App, app) # type: ignore[assignment] + asgi_app = _WrapASGI2(app) # type: ignore[arg-type] + self.app = asgi_app + self.app_state: dict[str, typing.Any] = {} + transport = _TestClientTransport( + self.app, + portal_factory=self._portal_factory, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + app_state=self.app_state, + client=client, + ) + if headers is None: + headers = {} + headers.setdefault("user-agent", "testclient") + super().__init__( + base_url=base_url, + headers=headers, + transport=transport, + follow_redirects=follow_redirects, + cookies=cookies, + ) + + @contextlib.contextmanager + def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + if self.portal is not None: + yield self.portal + else: + with anyio.from_thread.start_blocking_portal(**self.async_backend) as 
portal: + yield portal + + def request( # type: ignore[override] + self, + method: str, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + url = self._merge_url(url) + return super().request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def get( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().get( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def options( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: 
httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().options( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def head( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().head( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def post( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = 
httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().post( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def put( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().put( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def patch( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: 
httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().patch( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def delete( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().delete( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def websocket_connect( + self, + url: str, + subprotocols: typing.Sequence[str] | None = None, + **kwargs: typing.Any, + ) -> WebSocketTestSession: + url = urljoin("ws://testserver", url) + headers = kwargs.get("headers", {}) + headers.setdefault("connection", "upgrade") + headers.setdefault("sec-websocket-key", "testserver==") + headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + kwargs["headers"] = headers + try: + super().request("GET", url, **kwargs) + except _Upgrade as 
exc: + session = exc.session + else: + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + return session + + def __enter__(self) -> TestClient: + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend)) + + @stack.callback + def reset_portal() -> None: + self.portal = None + + send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = ( + anyio.create_memory_object_stream(math.inf) + ) + receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = ( + anyio.create_memory_object_stream(math.inf) + ) + for channel in (*send, *receive): + stack.callback(channel.close) + self.stream_send = StapledObjectStream(*send) + self.stream_receive = StapledObjectStream(*receive) + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.wait_shutdown) + + self.exit_stack = stack.pop_all() + + return self + + def __exit__(self, *args: typing.Any) -> None: + self.exit_stack.close() + + async def lifespan(self) -> None: + scope = {"type": "lifespan", "state": self.app_state} + try: + await self.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + await self.stream_send.send(None) + + async def wait_startup(self) -> None: + await self.stream_receive.send({"type": "lifespan.startup"}) + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + message = await receive() + assert message["type"] in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ) + if message["type"] == "lifespan.startup.failed": + await receive() + + async def wait_shutdown(self) -> None: + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + await 
self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() diff --git a/.venv/lib/python3.11/site-packages/starlette/types.py b/.venv/lib/python3.11/site-packages/starlette/types.py new file mode 100644 index 0000000000000000000000000000000000000000..893f872964c2e6df9a81fdba3dd8fadfeaab9731 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/types.py @@ -0,0 +1,24 @@ +import typing + +if typing.TYPE_CHECKING: + from starlette.requests import Request + from starlette.responses import Response + from starlette.websockets import WebSocket + +AppType = typing.TypeVar("AppType") + +Scope = typing.MutableMapping[str, typing.Any] +Message = typing.MutableMapping[str, typing.Any] + +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Message], typing.Awaitable[None]] + +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + +StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]] +StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]] +Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] + +HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"] +WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]] +ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler] diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__init__.py b/.venv/lib/python3.11/site-packages/watchfiles/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..877fbd573e66d63349a03b5b829f0d52467137a3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/__init__.py @@ -0,0 +1,17 @@ +from .filters import 
BaseFilter, DefaultFilter, PythonFilter +from .main import Change, awatch, watch +from .run import arun_process, run_process +from .version import VERSION + +__version__ = VERSION +__all__ = ( + 'watch', + 'awatch', + 'run_process', + 'arun_process', + 'Change', + 'BaseFilter', + 'DefaultFilter', + 'PythonFilter', + 'VERSION', +) diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__main__.py b/.venv/lib/python3.11/site-packages/watchfiles/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d396c2a7cb4a9c69d07b596337b1fffb249c0191 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/__main__.py @@ -0,0 +1,4 @@ +from .cli import cli + +if __name__ == '__main__': + cli() diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7492105e6f982ab5ced7ae6add5524f104ff9493 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/__main__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/__main__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f1a6905ec13f83d79839217c0f93dfc48963c36 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/__main__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/cli.cpython-311.pyc b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/cli.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..290a6e4d197070018de9fcebb3ec19d292c0290b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/cli.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/filters.cpython-311.pyc b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/filters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6353bc11fa92920cee820c265df76c9c62876e5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/filters.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/main.cpython-311.pyc b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/main.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70e310c30b3a822dc7e85406d653662b2d13b1b0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/main.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/run.cpython-311.pyc b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/run.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57d13afbf57be0a73c57b13f3bcec70d2776b001 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/run.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/version.cpython-311.pyc b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..064277262b0fd3d0069c02e5c62181f8b2876671 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/watchfiles/__pycache__/version.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/watchfiles/_rust_notify.pyi b/.venv/lib/python3.11/site-packages/watchfiles/_rust_notify.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e08cfff88f9897f616ba8e268052a13ccbf1cae8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/_rust_notify.pyi @@ -0,0 +1,111 @@ +from typing import Any, Literal, Protocol + 
+__all__ = 'RustNotify', 'WatchfilesRustInternalError' + +__version__: str +"""The package version as defined in `Cargo.toml`, modified to match python's versioning semantics.""" + +class AbstractEvent(Protocol): + def is_set(self) -> bool: ... + +class RustNotify: + """ + Interface to the Rust [notify](https://crates.io/crates/notify) crate which does + the heavy lifting of watching for file changes and grouping them into events. + """ + + def __init__( + self, + watch_paths: list[str], + debug: bool, + force_polling: bool, + poll_delay_ms: int, + recursive: bool, + ignore_permission_denied: bool, + ) -> None: + """ + Create a new `RustNotify` instance and start a thread to watch for changes. + + `FileNotFoundError` is raised if any of the paths do not exist. + + Args: + watch_paths: file system paths to watch for changes, can be directories or files + debug: if true, print details about all events to stderr + force_polling: if true, always use polling instead of file system notifications + poll_delay_ms: delay between polling for changes, only used if `force_polling=True` + recursive: if `True`, watch for changes in sub-directories recursively, otherwise watch only for changes in + the top-level directory, default is `True`. + ignore_permission_denied: if `True`, permission denied errors are ignored while watching changes. + """ + def watch( + self, + debounce_ms: int, + step_ms: int, + timeout_ms: int, + stop_event: AbstractEvent | None, + ) -> set[tuple[int, str]] | Literal['signal', 'stop', 'timeout']: + """ + Watch for changes. + + This method will wait `timeout_ms` milliseconds for changes, but once a change is detected, + it will group changes and return in no more than `debounce_ms` milliseconds. + + The GIL is released during a `step_ms` sleep on each iteration to avoid + blocking python. + + Args: + debounce_ms: maximum time in milliseconds to group changes over before returning. 
+ step_ms: time to wait for new changes in milliseconds, if no changes are detected + in this time, and at least one change has been detected, the changes are yielded. + timeout_ms: maximum time in milliseconds to wait for changes before returning, + `0` means wait indefinitely, `debounce_ms` takes precedence over `timeout_ms` once + a change is detected. + stop_event: event to check on every iteration to see if this function should return early. + The event should be an object which has an `is_set()` method which returns a boolean. + + Returns: + See below. + + Return values have the following meanings: + + * Change details as a `set` of `(event_type, path)` tuples, the event types are ints which match + [`Change`][watchfiles.Change], `path` is a string representing the path of the file that changed + * `'signal'` string, if a signal was received + * `'stop'` string, if the `stop_event` was set + * `'timeout'` string, if `timeout_ms` was exceeded + """ + def __enter__(self) -> RustNotify: + """ + Does nothing, but allows `RustNotify` to be used as a context manager. + + !!! note + + The watching thread is created when an instance is initiated, not on `__enter__`. + """ + def __exit__(self, *args: Any) -> None: + """ + Calls [`close`][watchfiles._rust_notify.RustNotify.close]. + """ + def close(self) -> None: + """ + Stops the watching thread. After `close` is called, the `RustNotify` instance can no + longer be used, calls to [`watch`][watchfiles._rust_notify.RustNotify.watch] will raise a `RuntimeError`. + + !!! note + + `close` is not required, just deleting the `RustNotify` instance will kill the thread + implicitly. + + As per [#163](https://github.com/samuelcolvin/watchfiles/issues/163) `close()` is only required because + in the event of an error, the traceback in `sys.exc_info` keeps a reference to `watchfiles.watch`'s + frame, so you can't rely on the `RustNotify` object being deleted, and thereby stopping + the watching thread. 
+ """ + +class WatchfilesRustInternalError(RuntimeError): + """ + Raised when RustNotify encounters an unknown error. + + If you get this a lot, please check [github](https://github.com/samuelcolvin/watchfiles/issues) issues + and create a new issue if your problem is not discussed. + """ diff --git a/.venv/lib/python3.11/site-packages/watchfiles/cli.py b/.venv/lib/python3.11/site-packages/watchfiles/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e1ddd57532543967def8d5a8799b8e54d6e233 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/cli.py @@ -0,0 +1,224 @@ +import argparse +import logging +import os +import shlex +import sys +from pathlib import Path +from textwrap import dedent +from typing import Any, Callable, List, Optional, Tuple, Union, cast + +from . import Change +from .filters import BaseFilter, DefaultFilter, PythonFilter +from .run import detect_target_type, import_string, run_process +from .version import VERSION + +logger = logging.getLogger('watchfiles.cli') + + +def resolve_path(path_str: str) -> Path: + path = Path(path_str) + if not path.exists(): + raise FileNotFoundError(path) + else: + return path.resolve() + + +def cli(*args_: str) -> None: + """ + Watch one or more directories and execute either a shell command or a python function on file changes. + + Example of watching the current directory and calling a python function: + + watchfiles foobar.main + + Example of watching python files in two local directories and calling a shell command: + + watchfiles --filter python 'pytest --lf' src tests + + See https://watchfiles.helpmanual.io/cli/ for more information. 
+ """ + args = args_ or sys.argv[1:] + parser = argparse.ArgumentParser( + prog='watchfiles', + description=dedent((cli.__doc__ or '').strip('\n')), + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument('target', help='Command or dotted function path to run') + parser.add_argument( + 'paths', nargs='*', default='.', help='Filesystem paths to watch, defaults to current directory' + ) + + parser.add_argument( + '--ignore-paths', + nargs='?', + type=str, + help=( + 'Specify directories to ignore, ' + 'to ignore multiple paths use a comma as separator, e.g. "env" or "env,node_modules"' + ), + ) + parser.add_argument( + '--target-type', + nargs='?', + type=str, + default='auto', + choices=['command', 'function', 'auto'], + help=( + 'Whether the target should be intercepted as a shell command or a python function, ' + 'defaults to "auto" which infers the target type from the target string' + ), + ) + parser.add_argument( + '--filter', + nargs='?', + type=str, + default='default', + help=( + 'Which files to watch, defaults to "default" which uses the "DefaultFilter", ' + '"python" uses the "PythonFilter", "all" uses no filter, ' + 'any other value is interpreted as a python function/class path which is imported' + ), + ) + parser.add_argument( + '--args', + nargs='?', + type=str, + help='Arguments to set on sys.argv before calling target function, used only if the target is a function', + ) + parser.add_argument('--verbose', action='store_true', help='Set log level to "debug", wins over `--verbosity`') + parser.add_argument( + '--non-recursive', action='store_true', help='Do not watch for changes in sub-directories recursively' + ) + parser.add_argument( + '--verbosity', + nargs='?', + type=str, + default='info', + choices=['warning', 'info', 'debug'], + help='Log level, defaults to "info"', + ) + parser.add_argument( + '--sigint-timeout', + nargs='?', + type=int, + default=5, + help='How long to wait for the sigint timeout before sending sigkill.', + 
) + parser.add_argument( + '--grace-period', + nargs='?', + type=float, + default=0, + help='Number of seconds after the process is started before watching for changes.', + ) + parser.add_argument( + '--sigkill-timeout', + nargs='?', + type=int, + default=1, + help='How long to wait for the sigkill timeout before issuing a timeout exception.', + ) + parser.add_argument( + '--ignore-permission-denied', + action='store_true', + help='Ignore permission denied errors while watching files and directories.', + ) + parser.add_argument('--version', '-V', action='version', version=f'%(prog)s v{VERSION}') + arg_namespace = parser.parse_args(args) + + if arg_namespace.verbose: + log_level = logging.DEBUG + else: + log_level = getattr(logging, arg_namespace.verbosity.upper()) + + hdlr = logging.StreamHandler() + hdlr.setLevel(log_level) + hdlr.setFormatter(logging.Formatter(fmt='[%(asctime)s] %(message)s', datefmt='%H:%M:%S')) + wg_logger = logging.getLogger('watchfiles') + wg_logger.addHandler(hdlr) + wg_logger.setLevel(log_level) + + if arg_namespace.target_type == 'auto': + target_type = detect_target_type(arg_namespace.target) + else: + target_type = arg_namespace.target_type + + if target_type == 'function': + logger.debug('target_type=function, attempting import of "%s"', arg_namespace.target) + import_exit(arg_namespace.target) + if arg_namespace.args: + sys.argv = [arg_namespace.target] + shlex.split(arg_namespace.args) + elif arg_namespace.args: + logger.warning('--args is only used when the target is a function') + + try: + paths = [resolve_path(p) for p in arg_namespace.paths] + except FileNotFoundError as e: + print(f'path "{e}" does not exist', file=sys.stderr) + sys.exit(1) + + watch_filter, watch_filter_str = build_filter(arg_namespace.filter, arg_namespace.ignore_paths) + + logger.info( + 'watchfiles v%s 👀 path=%s target="%s" (%s) filter=%s...', + VERSION, + ', '.join(f'"{p}"' for p in paths), + arg_namespace.target, + target_type, + watch_filter_str, + ) + + 
run_process( + *paths, + target=arg_namespace.target, + target_type=target_type, + watch_filter=watch_filter, + debug=log_level == logging.DEBUG, + sigint_timeout=arg_namespace.sigint_timeout, + sigkill_timeout=arg_namespace.sigkill_timeout, + recursive=not arg_namespace.non_recursive, + ignore_permission_denied=arg_namespace.ignore_permission_denied, + grace_period=arg_namespace.grace_period, + ) + + +def import_exit(function_path: str) -> Any: + cwd = os.getcwd() + if cwd not in sys.path: + sys.path.append(cwd) + + try: + return import_string(function_path) + except ImportError as e: + print(f'ImportError: {e}', file=sys.stderr) + sys.exit(1) + + +def build_filter( + filter_name: str, ignore_paths_str: Optional[str] +) -> Tuple[Union[None, DefaultFilter, Callable[[Change, str], bool]], str]: + ignore_paths: List[Path] = [] + if ignore_paths_str: + ignore_paths = [Path(p).resolve() for p in ignore_paths_str.split(',')] + + if filter_name == 'default': + return DefaultFilter(ignore_paths=ignore_paths), 'DefaultFilter' + elif filter_name == 'python': + return PythonFilter(ignore_paths=ignore_paths), 'PythonFilter' + elif filter_name == 'all': + if ignore_paths: + logger.warning('"--ignore-paths" argument ignored as "all" filter was selected') + return None, '(no filter)' + + watch_filter_cls = import_exit(filter_name) + if isinstance(watch_filter_cls, type) and issubclass(watch_filter_cls, DefaultFilter): + return watch_filter_cls(ignore_paths=ignore_paths), watch_filter_cls.__name__ + + if ignore_paths: + logger.warning('"--ignore-paths" argument ignored as filter is not a subclass of DefaultFilter') + + if isinstance(watch_filter_cls, type) and issubclass(watch_filter_cls, BaseFilter): + return watch_filter_cls(), watch_filter_cls.__name__ + else: + watch_filter = cast(Callable[[Change, str], bool], watch_filter_cls) + return watch_filter, repr(watch_filter_cls) diff --git a/.venv/lib/python3.11/site-packages/watchfiles/filters.py 
b/.venv/lib/python3.11/site-packages/watchfiles/filters.py new file mode 100644 index 0000000000000000000000000000000000000000..d97dfe8743706d0f45001a92049cafb2e4f1703c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/filters.py @@ -0,0 +1,149 @@ +import logging +import os +import re +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Sequence, Union + +__all__ = 'BaseFilter', 'DefaultFilter', 'PythonFilter' +logger = logging.getLogger('watchfiles.watcher') + + +if TYPE_CHECKING: + from .main import Change + + +class BaseFilter: + """ + Useful base class for creating filters. `BaseFilter` should be inherited and configured, rather than used + directly. + + The class supports ignoring files in 3 ways: + """ + + __slots__ = '_ignore_dirs', '_ignore_entity_regexes', '_ignore_paths' + ignore_dirs: Sequence[str] = () + """Full names of directories to ignore, an obvious example would be `.git`.""" + ignore_entity_patterns: Sequence[str] = () + """ + Patterns of files or directories to ignore, these are compiled into regexes. + + "entity" here refers to the specific file or directory - basically the result of `path.split(os.sep)[-1]`, + an obvious example would be `r'\\.py[cod]$'`. + """ + ignore_paths: Sequence[Union[str, Path]] = () + """ + Full paths to ignore, e.g. `/home/users/.cache` or `C:\\Users\\user\\.cache`. + """ + + def __init__(self) -> None: + self._ignore_dirs = set(self.ignore_dirs) + self._ignore_entity_regexes = tuple(re.compile(r) for r in self.ignore_entity_patterns) + self._ignore_paths = tuple(map(str, self.ignore_paths)) + + def __call__(self, change: 'Change', path: str) -> bool: + """ + Instances of `BaseFilter` subclasses can be used as callables. + Args: + change: The type of change that occurred, see [`Change`][watchfiles.Change]. + path: the raw path of the file or directory that changed. + + Returns: + True if the file should be included in changes, False if it should be ignored. 
+ """ + parts = path.lstrip(os.sep).split(os.sep) + if any(p in self._ignore_dirs for p in parts): + return False + + entity_name = parts[-1] + if any(r.search(entity_name) for r in self._ignore_entity_regexes): + return False + elif self._ignore_paths and path.startswith(self._ignore_paths): + return False + else: + return True + + def __repr__(self) -> str: + args = ', '.join(f'{k}={getattr(self, k, None)!r}' for k in self.__slots__) + return f'{self.__class__.__name__}({args})' + + +class DefaultFilter(BaseFilter): + """ + The default filter, which ignores files and directories that you might commonly want to ignore. + """ + + ignore_dirs: Sequence[str] = ( + '__pycache__', + '.git', + '.hg', + '.svn', + '.tox', + '.venv', + '.idea', + 'node_modules', + '.mypy_cache', + '.pytest_cache', + '.hypothesis', + ) + """Directory names to ignore.""" + + ignore_entity_patterns: Sequence[str] = ( + r'\.py[cod]$', + r'\.___jb_...___$', + r'\.sw.$', + '~$', + r'^\.\#', + r'^\.DS_Store$', + r'^flycheck_', + ) + """File/Directory name patterns to ignore.""" + + def __init__( + self, + *, + ignore_dirs: Optional[Sequence[str]] = None, + ignore_entity_patterns: Optional[Sequence[str]] = None, + ignore_paths: Optional[Sequence[Union[str, Path]]] = None, + ) -> None: + """ + Args: + ignore_dirs: if not `None`, overrides the `ignore_dirs` value set on the class. + ignore_entity_patterns: if not `None`, overrides the `ignore_entity_patterns` value set on the class. + ignore_paths: if not `None`, overrides the `ignore_paths` value set on the class. 
+ """ + if ignore_dirs is not None: + self.ignore_dirs = ignore_dirs + if ignore_entity_patterns is not None: + self.ignore_entity_patterns = ignore_entity_patterns + if ignore_paths is not None: + self.ignore_paths = ignore_paths + + super().__init__() + + +class PythonFilter(DefaultFilter): + """ + A filter for Python files, since this class inherits from [`DefaultFilter`][watchfiles.DefaultFilter] + it will ignore files and directories that you might commonly want to ignore as well as filtering out + all changes except in Python files (files with extensions `('.py', '.pyx', '.pyd')`). + """ + + def __init__( + self, + *, + ignore_paths: Optional[Sequence[Union[str, Path]]] = None, + extra_extensions: Sequence[str] = (), + ) -> None: + """ + Args: + ignore_paths: The paths to ignore, see [`BaseFilter`][watchfiles.BaseFilter]. + extra_extensions: extra extensions to ignore. + + `ignore_paths` and `extra_extensions` can be passed as arguments partly to support [CLI](../cli.md) usage where + `--ignore-paths` and `--extensions` can be passed as arguments. 
+ """ + self.extensions = ('.py', '.pyx', '.pyd') + tuple(extra_extensions) + super().__init__(ignore_paths=ignore_paths) + + def __call__(self, change: 'Change', path: str) -> bool: + return path.endswith(self.extensions) and super().__call__(change, path) diff --git a/.venv/lib/python3.11/site-packages/watchfiles/main.py b/.venv/lib/python3.11/site-packages/watchfiles/main.py new file mode 100644 index 0000000000000000000000000000000000000000..5af1363fb3f86f8f397d693c15f5822263c33822 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/main.py @@ -0,0 +1,373 @@ +import logging +import os +import sys +import warnings +from enum import IntEnum +from pathlib import Path +from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, Optional, Set, Tuple, Union + +import anyio + +from ._rust_notify import RustNotify +from .filters import DefaultFilter + +__all__ = 'watch', 'awatch', 'Change', 'FileChange' +logger = logging.getLogger('watchfiles.main') + + +class Change(IntEnum): + """ + Enum representing the type of change that occurred. + """ + + added = 1 + """A new file or directory was added.""" + modified = 2 + """A file or directory was modified, can be either a metadata or data change.""" + deleted = 3 + """A file or directory was deleted.""" + + def raw_str(self) -> str: + return self.name + + +FileChange = Tuple[Change, str] +""" +A tuple representing a file change, first element is a [`Change`][watchfiles.Change] member, second is the path +of the file or directory that changed. +""" + +if TYPE_CHECKING: + import asyncio + from typing import Protocol + + import trio + + AnyEvent = Union[anyio.Event, asyncio.Event, trio.Event] + + class AbstractEvent(Protocol): + def is_set(self) -> bool: ... 
+ + +def watch( + *paths: Union[Path, str], + watch_filter: Optional[Callable[['Change', str], bool]] = DefaultFilter(), + debounce: int = 1_600, + step: int = 50, + stop_event: Optional['AbstractEvent'] = None, + rust_timeout: int = 5_000, + yield_on_timeout: bool = False, + debug: Optional[bool] = None, + raise_interrupt: bool = True, + force_polling: Optional[bool] = None, + poll_delay_ms: int = 300, + recursive: bool = True, + ignore_permission_denied: Optional[bool] = None, +) -> Generator[Set[FileChange], None, None]: + """ + Watch one or more paths and yield a set of changes whenever files change. + + The paths watched can be directories or files, directories are watched recursively - changes in subdirectories + are also detected. + + #### Force polling + + Notify will fall back to file polling if it can't use file system notifications, but we also force Notify + to use polling if the `force_polling` argument is `True`; if `force_polling` is unset (or `None`), we enable + force polling thus: + + * if the `WATCHFILES_FORCE_POLLING` environment variable exists and is not empty: + * if the value is `false`, `disable` or `disabled`, force polling is disabled + * otherwise, force polling is enabled + * otherwise, we enable force polling only if we detect we're running on WSL (Windows Subsystem for Linux) + + It is also possible to change the poll delay between iterations, it can be changed to maintain a good response time + and an appropriate CPU consumption using the `poll_delay_ms` argument, we change poll delay thus: + + * if file polling is enabled and the `WATCHFILES_POLL_DELAY_MS` env var exists and it is numeric, we use that + * otherwise, we use the argument value + + Args: + *paths: filesystem paths to watch. + watch_filter: callable used to filter out changes which are not important, you can either use a raw callable + or a [`BaseFilter`][watchfiles.BaseFilter] instance, + defaults to an instance of [`DefaultFilter`][watchfiles.DefaultFilter]. 
To keep all changes, use `None`. + debounce: maximum time in milliseconds to group changes over before yielding them. + step: time to wait for new changes in milliseconds, if no changes are detected in this time, and + at least one change has been detected, the changes are yielded. + stop_event: event to stop watching, if this is set, the generator will stop iteration, + this can be anything with an `is_set()` method which returns a bool, e.g. `threading.Event()`. + rust_timeout: maximum time in milliseconds to wait in the rust code for changes, `0` means no timeout. + yield_on_timeout: if `True`, the generator will yield upon timeout in rust even if no changes are detected. + debug: whether to print information about all filesystem changes in rust to stdout, if `None` will use the + `WATCHFILES_DEBUG` environment variable. + raise_interrupt: whether to re-raise `KeyboardInterrupt`s, or suppress the error and just stop iterating. + force_polling: See [Force polling](#force-polling) above. + poll_delay_ms: delay between polling for changes, only used if `force_polling=True`. + recursive: if `True`, watch for changes in sub-directories recursively, otherwise watch only for changes in the + top-level directory, default is `True`. + ignore_permission_denied: if `True`, will ignore permission denied errors, otherwise will raise them by default. + Setting the `WATCHFILES_IGNORE_PERMISSION_DENIED` environment variable will set this value too. + + Yields: + The generator yields sets of [`FileChange`][watchfiles.main.FileChange]s. 
+ + ```py title="Example of watch usage" + from watchfiles import watch + + for changes in watch('./first/dir', './second/dir', raise_interrupt=False): + print(changes) + ``` + """ + force_polling = _default_force_polling(force_polling) + poll_delay_ms = _default_poll_delay_ms(poll_delay_ms) + ignore_permission_denied = _default_ignore_permission_denied(ignore_permission_denied) + debug = _default_debug(debug) + with RustNotify( + [str(p) for p in paths], debug, force_polling, poll_delay_ms, recursive, ignore_permission_denied + ) as watcher: + while True: + raw_changes = watcher.watch(debounce, step, rust_timeout, stop_event) + if raw_changes == 'timeout': + if yield_on_timeout: + yield set() + else: + logger.debug('rust notify timeout, continuing') + elif raw_changes == 'signal': + if raise_interrupt: + raise KeyboardInterrupt + else: + logger.warning('KeyboardInterrupt caught, stopping watch') + return + elif raw_changes == 'stop': + return + else: + changes = _prep_changes(raw_changes, watch_filter) + if changes: + _log_changes(changes) + yield changes + else: + logger.debug('all changes filtered out, raw_changes=%s', raw_changes) + + +async def awatch( # C901 + *paths: Union[Path, str], + watch_filter: Optional[Callable[[Change, str], bool]] = DefaultFilter(), + debounce: int = 1_600, + step: int = 50, + stop_event: Optional['AnyEvent'] = None, + rust_timeout: Optional[int] = None, + yield_on_timeout: bool = False, + debug: Optional[bool] = None, + raise_interrupt: Optional[bool] = None, + force_polling: Optional[bool] = None, + poll_delay_ms: int = 300, + recursive: bool = True, + ignore_permission_denied: Optional[bool] = None, +) -> AsyncGenerator[Set[FileChange], None]: + """ + Asynchronous equivalent of [`watch`][watchfiles.watch] using threads to wait for changes. + Arguments match those of [`watch`][watchfiles.watch] except `stop_event`. + + All async methods use [anyio](https://anyio.readthedocs.io/en/latest/) to run the event loop. 
+ + Unlike [`watch`][watchfiles.watch] `KeyboardInterrupt` cannot be suppressed by `awatch` so they need to be caught + where `asyncio.run` or equivalent is called. + + Args: + *paths: filesystem paths to watch. + watch_filter: matches the same argument of [`watch`][watchfiles.watch]. + debounce: matches the same argument of [`watch`][watchfiles.watch]. + step: matches the same argument of [`watch`][watchfiles.watch]. + stop_event: `anyio.Event` which can be used to stop iteration, see example below. + rust_timeout: matches the same argument of [`watch`][watchfiles.watch], except that `None` means + use `1_000` on Windows and `5_000` on other platforms thus helping with exiting on `Ctrl+C` on Windows, + see [#110](https://github.com/samuelcolvin/watchfiles/issues/110). + yield_on_timeout: matches the same argument of [`watch`][watchfiles.watch]. + debug: matches the same argument of [`watch`][watchfiles.watch]. + raise_interrupt: This is deprecated, `KeyboardInterrupt` will cause this coroutine to be cancelled and then + be raised by the top level `asyncio.run` call or equivalent, and should be caught there. + See [#136](https://github.com/samuelcolvin/watchfiles/issues/136) + force_polling: if true, always use polling instead of file system notifications, default is `None` where + `force_polling` is set to `True` if the `WATCHFILES_FORCE_POLLING` environment variable exists. + poll_delay_ms: delay between polling for changes, only used if `force_polling=True`. + `poll_delay_ms` can be changed via the `WATCHFILES_POLL_DELAY_MS` environment variable. + recursive: if `True`, watch for changes in sub-directories recursively, otherwise watch only for changes in the + top-level directory, default is `True`. + ignore_permission_denied: if `True`, will ignore permission denied errors, otherwise will raise them by default. + Setting the `WATCHFILES_IGNORE_PERMISSION_DENIED` environment variable will set this value too. 
+ + Yields: + The generator yields sets of [`FileChange`][watchfiles.main.FileChange]s. + + ```py title="Example of awatch usage" + import asyncio + from watchfiles import awatch + + async def main(): + async for changes in awatch('./first/dir', './second/dir'): + print(changes) + + if __name__ == '__main__': + try: + asyncio.run(main()) + except KeyboardInterrupt: + print('stopped via KeyboardInterrupt') + ``` + + ```py title="Example of awatch usage with a stop event" + import asyncio + from watchfiles import awatch + + async def main(): + stop_event = asyncio.Event() + + async def stop_soon(): + await asyncio.sleep(3) + stop_event.set() + + stop_soon_task = asyncio.create_task(stop_soon()) + + async for changes in awatch('/path/to/dir', stop_event=stop_event): + print(changes) + + # cleanup by awaiting the (now complete) stop_soon_task + await stop_soon_task + + asyncio.run(main()) + ``` + """ + if raise_interrupt is not None: + warnings.warn( + 'raise_interrupt is deprecated, KeyboardInterrupt will cause this coroutine to be cancelled and then ' + 'be raised by the top level asyncio.run call or equivalent, and should be caught there. 
See #136.', + DeprecationWarning, + ) + + if stop_event is None: + stop_event_: AnyEvent = anyio.Event() + else: + stop_event_ = stop_event + + force_polling = _default_force_polling(force_polling) + poll_delay_ms = _default_poll_delay_ms(poll_delay_ms) + ignore_permission_denied = _default_ignore_permission_denied(ignore_permission_denied) + debug = _default_debug(debug) + with RustNotify( + [str(p) for p in paths], debug, force_polling, poll_delay_ms, recursive, ignore_permission_denied + ) as watcher: + timeout = _calc_async_timeout(rust_timeout) + CancelledError = anyio.get_cancelled_exc_class() + + while True: + async with anyio.create_task_group() as tg: + try: + raw_changes = await anyio.to_thread.run_sync(watcher.watch, debounce, step, timeout, stop_event_) + except (CancelledError, KeyboardInterrupt): + stop_event_.set() + # suppressing KeyboardInterrupt wouldn't stop it getting raised by the top level asyncio.run call + raise + tg.cancel_scope.cancel() + + if raw_changes == 'timeout': + if yield_on_timeout: + yield set() + else: + logger.debug('rust notify timeout, continuing') + elif raw_changes == 'stop': + return + elif raw_changes == 'signal': + # in theory the watch thread should never get a signal + raise RuntimeError('watch thread unexpectedly received a signal') + else: + changes = _prep_changes(raw_changes, watch_filter) + if changes: + _log_changes(changes) + yield changes + else: + logger.debug('all changes filtered out, raw_changes=%s', raw_changes) + + +def _prep_changes( + raw_changes: Set[Tuple[int, str]], watch_filter: Optional[Callable[[Change, str], bool]] +) -> Set[FileChange]: + # if we wanted to be really snazzy, we could move this into rust + changes = {(Change(change), path) for change, path in raw_changes} + if watch_filter: + changes = {c for c in changes if watch_filter(c[0], c[1])} + return changes + + +def _log_changes(changes: Set[FileChange]) -> None: + if logger.isEnabledFor(logging.INFO): # pragma: no branch + count = 
len(changes) + plural = '' if count == 1 else 's' + if logger.isEnabledFor(logging.DEBUG): + logger.debug('%d change%s detected: %s', count, plural, changes) + else: + logger.info('%d change%s detected', count, plural) + + +def _calc_async_timeout(timeout: Optional[int]) -> int: + """ + see https://github.com/samuelcolvin/watchfiles/issues/110 + """ + if timeout is None: + if sys.platform == 'win32': + return 1_000 + else: + return 5_000 + else: + return timeout + + +def _default_force_polling(force_polling: Optional[bool]) -> bool: + """ + See docstring for `watch` above for details. + + See samuelcolvin/watchfiles#167 and samuelcolvin/watchfiles#187 for discussion and rationale. + """ + if force_polling is not None: + return force_polling + env_var = os.getenv('WATCHFILES_FORCE_POLLING') + if env_var: + return env_var.lower() not in {'false', 'disable', 'disabled'} + else: + return _auto_force_polling() + + +def _default_poll_delay_ms(poll_delay_ms: int) -> int: + """ + See docstring for `watch` above for details. + """ + env_var = os.getenv('WATCHFILES_POLL_DELAY_MS') + if env_var and env_var.isdecimal(): + return int(env_var) + else: + return poll_delay_ms + + +def _default_debug(debug: Optional[bool]) -> bool: + if debug is not None: + return debug + env_var = os.getenv('WATCHFILES_DEBUG') + return bool(env_var) + + +def _auto_force_polling() -> bool: + """ + Whether to auto-enable force polling, it should be enabled automatically only on WSL. + + See samuelcolvin/watchfiles#187 for discussion. 
+ """ + import platform + + uname = platform.uname() + return 'microsoft-standard' in uname.release.lower() and uname.system.lower() == 'linux' + + +def _default_ignore_permission_denied(ignore_permission_denied: Optional[bool]) -> bool: + if ignore_permission_denied is not None: + return ignore_permission_denied + env_var = os.getenv('WATCHFILES_IGNORE_PERMISSION_DENIED') + return bool(env_var) diff --git a/.venv/lib/python3.11/site-packages/watchfiles/py.typed b/.venv/lib/python3.11/site-packages/watchfiles/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..7cd6d6f0321bc62f12078af229fb4d0063b9a951 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The watchfiles package uses inline types. diff --git a/.venv/lib/python3.11/site-packages/watchfiles/run.py b/.venv/lib/python3.11/site-packages/watchfiles/run.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d589ed2b3065fad5ff774ad28f160a5bfcf8ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/run.py @@ -0,0 +1,438 @@ +import contextlib +import json +import logging +import os +import re +import shlex +import signal +import subprocess +import sys +from importlib import import_module +from multiprocessing import get_context +from multiprocessing.context import SpawnProcess +from pathlib import Path +from time import sleep +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union + +import anyio + +from .filters import DefaultFilter +from .main import Change, FileChange, awatch, watch + +if TYPE_CHECKING: + from typing import Literal + +__all__ = 'run_process', 'arun_process', 'detect_target_type', 'import_string' +logger = logging.getLogger('watchfiles.main') + + +def run_process( + *paths: Union[Path, str], + target: Union[str, Callable[..., Any]], + args: Tuple[Any, ...] 
= (), + kwargs: Optional[Dict[str, Any]] = None, + target_type: "Literal['function', 'command', 'auto']" = 'auto', + callback: Optional[Callable[[Set[FileChange]], None]] = None, + watch_filter: Optional[Callable[[Change, str], bool]] = DefaultFilter(), + grace_period: float = 0, + debounce: int = 1_600, + step: int = 50, + debug: Optional[bool] = None, + sigint_timeout: int = 5, + sigkill_timeout: int = 1, + recursive: bool = True, + ignore_permission_denied: bool = False, +) -> int: + """ + Run a process and restart it upon file changes. + + `run_process` can work in two ways: + + * Using `multiprocessing.Process` † to run a python function + * Or, using `subprocess.Popen` to run a command + + !!! note + + **†** technically `multiprocessing.get_context('spawn').Process` to avoid forking and improve + code reload/import. + + Internally, `run_process` uses [`watch`][watchfiles.watch] with `raise_interrupt=False` so the function + exits cleanly upon `Ctrl+C`. + + Args: + *paths: matches the same argument of [`watch`][watchfiles.watch] + target: function or command to run + args: arguments to pass to `target`, only used if `target` is a function + kwargs: keyword arguments to pass to `target`, only used if `target` is a function + target_type: type of target. Can be `'function'`, `'command'`, or `'auto'` in which case + [`detect_target_type`][watchfiles.run.detect_target_type] is used to determine the type. 
+ callback: function to call on each reload, the function should accept a set of changes as the sole argument + watch_filter: matches the same argument of [`watch`][watchfiles.watch] + grace_period: number of seconds after the process is started before watching for changes + debounce: matches the same argument of [`watch`][watchfiles.watch] + step: matches the same argument of [`watch`][watchfiles.watch] + debug: matches the same argument of [`watch`][watchfiles.watch] + sigint_timeout: the number of seconds to wait after sending sigint before sending sigkill + sigkill_timeout: the number of seconds to wait after sending sigkill before raising an exception + recursive: matches the same argument of [`watch`][watchfiles.watch] + + Returns: + number of times the function was reloaded. + + ```py title="Example of run_process running a function" + from watchfiles import run_process + + def callback(changes): + print('changes detected:', changes) + + def foobar(a, b): + print('foobar called with:', a, b) + + if __name__ == '__main__': + run_process('./path/to/dir', target=foobar, args=(1, 2), callback=callback) + ``` + + As well as using a `callback` function, changes can be accessed from within the target function, + using the `WATCHFILES_CHANGES` environment variable. + + ```py title="Example of run_process accessing changes" + from watchfiles import run_process + + def foobar(a, b, c): + # changes will be an empty list "[]" the first time the function is called + changes = os.getenv('WATCHFILES_CHANGES') + changes = json.loads(changes) + print('foobar called due to changes:', changes) + + if __name__ == '__main__': + run_process('./path/to/dir', target=foobar, args=(1, 2, 3)) + ``` + + Again with the target as `command`, `WATCHFILES_CHANGES` can be used + to access changes. 
+ + ```bash title="example.sh" + echo "changers: ${WATCHFILES_CHANGES}" + ``` + + ```py title="Example of run_process running a command" + from watchfiles import run_process + + if __name__ == '__main__': + run_process('.', target='./example.sh') + ``` + """ + if target_type == 'auto': + target_type = detect_target_type(target) + + logger.debug('running "%s" as %s', target, target_type) + catch_sigterm() + process = start_process(target, target_type, args, kwargs) + reloads = 0 + + if grace_period: + logger.debug('sleeping for %s seconds before watching for changes', grace_period) + sleep(grace_period) + + try: + for changes in watch( + *paths, + watch_filter=watch_filter, + debounce=debounce, + step=step, + debug=debug, + raise_interrupt=False, + recursive=recursive, + ignore_permission_denied=ignore_permission_denied, + ): + callback and callback(changes) + process.stop(sigint_timeout=sigint_timeout, sigkill_timeout=sigkill_timeout) + process = start_process(target, target_type, args, kwargs, changes) + reloads += 1 + finally: + process.stop() + return reloads + + +async def arun_process( + *paths: Union[Path, str], + target: Union[str, Callable[..., Any]], + args: Tuple[Any, ...] = (), + kwargs: Optional[Dict[str, Any]] = None, + target_type: "Literal['function', 'command', 'auto']" = 'auto', + callback: Optional[Callable[[Set[FileChange]], Any]] = None, + watch_filter: Optional[Callable[[Change, str], bool]] = DefaultFilter(), + grace_period: float = 0, + debounce: int = 1_600, + step: int = 50, + debug: Optional[bool] = None, + recursive: bool = True, + ignore_permission_denied: bool = False, +) -> int: + """ + Async equivalent of [`run_process`][watchfiles.run_process], all arguments match those of `run_process` except + `callback` which can be a coroutine. + + Starting and stopping the process and watching for changes is done in a separate thread. 
+ + As with `run_process`, internally `arun_process` uses [`awatch`][watchfiles.awatch], however `KeyboardInterrupt` + cannot be caught and suppressed in `awatch` so these errors need to be caught separately, see below. + + ```py title="Example of arun_process usage" + import asyncio + from watchfiles import arun_process + + async def callback(changes): + await asyncio.sleep(0.1) + print('changes detected:', changes) + + def foobar(a, b): + print('foobar called with:', a, b) + + async def main(): + await arun_process('.', target=foobar, args=(1, 2), callback=callback) + + if __name__ == '__main__': + try: + asyncio.run(main()) + except KeyboardInterrupt: + print('stopped via KeyboardInterrupt') + ``` + """ + import inspect + + if target_type == 'auto': + target_type = detect_target_type(target) + + logger.debug('running "%s" as %s', target, target_type) + catch_sigterm() + process = await anyio.to_thread.run_sync(start_process, target, target_type, args, kwargs) + reloads = 0 + + if grace_period: + logger.debug('sleeping for %s seconds before watching for changes', grace_period) + await anyio.sleep(grace_period) + + async for changes in awatch( + *paths, + watch_filter=watch_filter, + debounce=debounce, + step=step, + debug=debug, + recursive=recursive, + ignore_permission_denied=ignore_permission_denied, + ): + if callback is not None: + r = callback(changes) + if inspect.isawaitable(r): + await r + + await anyio.to_thread.run_sync(process.stop) + process = await anyio.to_thread.run_sync(start_process, target, target_type, args, kwargs, changes) + reloads += 1 + await anyio.to_thread.run_sync(process.stop) + return reloads + + +# Use spawn context to make sure code run in subprocess +# does not reuse imported modules in main process/context +spawn_context = get_context('spawn') + + +def split_cmd(cmd: str) -> List[str]: + import platform + + posix = platform.uname().system.lower() != 'windows' + return shlex.split(cmd, posix=posix) + + +def start_process( + 
target: Union[str, Callable[..., Any]], + target_type: "Literal['function', 'command']", + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]], + changes: Optional[Set[FileChange]] = None, +) -> 'CombinedProcess': + if changes is None: + changes_env_var = '[]' + else: + changes_env_var = json.dumps([[c.raw_str(), p] for c, p in changes]) + + os.environ['WATCHFILES_CHANGES'] = changes_env_var + + process: Union[SpawnProcess, subprocess.Popen[bytes]] + if target_type == 'function': + kwargs = kwargs or {} + if isinstance(target, str): + args = target, get_tty_path(), args, kwargs + target_ = run_function + kwargs = {} + else: + target_ = target + + process = spawn_context.Process(target=target_, args=args, kwargs=kwargs) + process.start() + else: + if args or kwargs: + logger.warning('ignoring args and kwargs for "command" target') + + assert isinstance(target, str), 'target must be a string to run as a command' + popen_args = split_cmd(target) + process = subprocess.Popen(popen_args) + return CombinedProcess(process) + + +def detect_target_type(target: Union[str, Callable[..., Any]]) -> "Literal['function', 'command']": + """ + Used by [`run_process`][watchfiles.run_process], [`arun_process`][watchfiles.arun_process] + and indirectly the CLI to determine the target type with `target_type` is `auto`. + + Detects the target type - either `function` or `command`. This method is only called with `target_type='auto'`. + + The following logic is employed: + + * If `target` is not a string, it is assumed to be a function + * If `target` ends with `.py` or `.sh`, it is assumed to be a command + * Otherwise, the target is assumed to be a function if it matches the regex `[a-zA-Z0-9_]+(\\.[a-zA-Z0-9_]+)+` + + If this logic does not work for you, specify the target type explicitly using the `target_type` function argument + or `--target-type` command line argument. 
+ + Args: + target: The target value + + Returns: + either `'function'` or `'command'` + """ + if not isinstance(target, str): + return 'function' + elif target.endswith(('.py', '.sh')): + return 'command' + elif re.fullmatch(r'[a-zA-Z0-9_]+(\.[a-zA-Z0-9_]+)+', target): + return 'function' + else: + return 'command' + + +class CombinedProcess: + def __init__(self, p: 'Union[SpawnProcess, subprocess.Popen[bytes]]'): + self._p = p + assert self.pid is not None, 'process not yet spawned' + + def stop(self, sigint_timeout: int = 5, sigkill_timeout: int = 1) -> None: + os.environ.pop('WATCHFILES_CHANGES', None) + if self.is_alive(): + logger.debug('stopping process...') + + os.kill(self.pid, signal.SIGINT) + + try: + self.join(sigint_timeout) + except subprocess.TimeoutExpired: + # Capture this exception to allow the self.exitcode to be reached. + # This will allow the SIGKILL to be sent, otherwise it is swallowed up. + logger.warning('SIGINT timed out after %r seconds', sigint_timeout) + pass + + if self.exitcode is None: + logger.warning('process has not terminated, sending SIGKILL') + os.kill(self.pid, signal.SIGKILL) + self.join(sigkill_timeout) + else: + logger.debug('process stopped') + else: + logger.warning('process already dead, exit code: %d', self.exitcode) + + def is_alive(self) -> bool: + if isinstance(self._p, SpawnProcess): + return self._p.is_alive() + else: + return self._p.poll() is None + + @property + def pid(self) -> int: + # we check the process has always been spawned when CombinedProcess is initialised + return self._p.pid # type: ignore[return-value] + + def join(self, timeout: int) -> None: + if isinstance(self._p, SpawnProcess): + self._p.join(timeout) + else: + self._p.wait(timeout) + + @property + def exitcode(self) -> Optional[int]: + if isinstance(self._p, SpawnProcess): + return self._p.exitcode + else: + return self._p.returncode + + +def run_function(function: str, tty_path: Optional[str], args: Tuple[Any, ...], kwargs: Dict[str, Any]) 
-> None: + with set_tty(tty_path): + func = import_string(function) + func(*args, **kwargs) + + +def import_string(dotted_path: str) -> Any: + """ + Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the + last name in the path. Raise ImportError if the import fails. + """ + try: + module_path, class_name = dotted_path.strip(' ').rsplit('.', 1) + except ValueError as e: + raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e + + module = import_module(module_path) + try: + return getattr(module, class_name) + except AttributeError as e: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e + + +def get_tty_path() -> Optional[str]: # pragma: no cover + """ + Return the path to the current TTY, if any. + + Virtually impossible to test in pytest, hence no cover. + """ + try: + return os.ttyname(sys.stdin.fileno()) + except OSError: + # fileno() always fails with pytest + return '/dev/tty' + except AttributeError: + # on Windows. No idea of a better solution + return None + + +@contextlib.contextmanager +def set_tty(tty_path: Optional[str]) -> Generator[None, None, None]: + if tty_path: + try: + with open(tty_path) as tty: # pragma: no cover + sys.stdin = tty + yield + except OSError: + # eg. "No such device or address: '/dev/tty'", see https://github.com/samuelcolvin/watchfiles/issues/40 + yield + else: + # currently on windows tty_path is None and there's nothing we can do here + yield + + +def raise_keyboard_interrupt(signum: int, _frame: Any) -> None: # pragma: no cover + logger.warning('received signal %s, raising KeyboardInterrupt', signal.Signals(signum)) + raise KeyboardInterrupt + + +def catch_sigterm() -> None: + """ + Catch SIGTERM and raise KeyboardInterrupt instead. This means watchfiles will stop quickly + on `docker compose stop` and other cases where SIGTERM is sent. 
+ + Without this the watchfiles process will be killed while a running process will continue uninterrupted. + """ + logger.debug('registering handler for SIGTERM on watchfiles process %d', os.getpid()) + signal.signal(signal.SIGTERM, raise_keyboard_interrupt) diff --git a/.venv/lib/python3.11/site-packages/watchfiles/version.py b/.venv/lib/python3.11/site-packages/watchfiles/version.py new file mode 100644 index 0000000000000000000000000000000000000000..f55721f14ffbf4bdff2e62c701ddf86a3b044483 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/watchfiles/version.py @@ -0,0 +1,5 @@ +from ._rust_notify import __version__ + +__all__ = ('VERSION',) + +VERSION = __version__