diff --git a/.gitattributes b/.gitattributes index 65584d1c6e10ad6b9324a0a08a2dc06e2c74315b..579d7037e08df30397112690af400a5287e636fa 100644 --- a/.gitattributes +++ b/.gitattributes @@ -207,3 +207,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/wrapt/_wrappers.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/__pycache__/typing_extensions.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-x86_64-linux-gnu.so b/.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..6c198ad8c5b46411e7ab4d069d3caa86d0e3eab8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/rpds/rpds.cpython-311-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aded6ee5bd881096565cbd54a06c9cb432b1ec9e69b4fdc29f37f32c4573be16 +size 1015312 diff --git a/.venv/lib/python3.11/site-packages/starlette/__init__.py b/.venv/lib/python3.11/site-packages/starlette/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd289900246260c4fbcee5602428d7e4a98d8a3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/__init__.py @@ -0,0 +1 @@ +__version__ = "0.45.3" diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705d8457859082d1dbd6da4f39aed1a6fc1faba6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/_exception_handler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/_exception_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71d6fb21c513c90da6ba8cb079b78d44dbae14a3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/_exception_handler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da94ecd0d204aa9a4cdae27f501320b8b98aa7e5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/applications.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/applications.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d692306ef799db67ed8d70116c4e5525531ce58c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/applications.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/authentication.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/authentication.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc6f8d3275a34d88838393b41575a9496d852207 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/authentication.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/background.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/background.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41112e51c0306562963652cd01408654b5aa7b82 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/background.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/concurrency.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/concurrency.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db95e505123863419b81e5ddcd789b206c12fcb9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/concurrency.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b32bce441f2cc921ab06638533b36212617cdb6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/convertors.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/convertors.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..270302ba2eb6e3da2000e24f8e116993eb4fb75a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/convertors.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/datastructures.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/datastructures.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3347d5d434a20e6a266b9ca4467e517587075f03 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/datastructures.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/endpoints.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/endpoints.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30da299826accc9da32c1d107c8c7fcec7f078cf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/endpoints.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..061536f9ae2f0f5e79a6773152a0f6d4a92198b1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/formparsers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/formparsers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2d4e89b789ffcba3c5b7f60f652091e790761c1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/formparsers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/requests.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/requests.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d9b6af3fc64183b18716a62d953b727669cf4f7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/requests.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/responses.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/responses.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0b961d29195311950a8c4208c9dd03f86878656 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/responses.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/routing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/routing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdb7308e4fc891ae75be89916206a6022ac76df7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/routing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/schemas.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/schemas.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bff1d6af4e80f12d2f0353e43cf13e5a3fce584d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/schemas.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/staticfiles.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/staticfiles.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a68e0a839053b2c773b0295c3d01ad5fdf87574 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/staticfiles.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/status.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/status.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1474b2522abf8771c2347ec29803c86e1a258ce4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/status.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/templating.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/templating.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..848288e6a2ce00606bc044c67a745df18d115506 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/templating.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/testclient.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/testclient.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d22079c12def52258aef631793c9e19026d57a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/testclient.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/types.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7ab8eab285f04418b0511a4c08f2193d61ccde1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/types.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/__pycache__/websockets.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/__pycache__/websockets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2732347c533c3c1725263ed7e06f7c3ebddc1db0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/__pycache__/websockets.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/_utils.py b/.venv/lib/python3.11/site-packages/starlette/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c389dcb29a0b021e013c5a79d1da8e0eb9eb47f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/_utils.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import asyncio +import functools +import sys +import typing +from contextlib import contextmanager + +from starlette.types import Scope + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import TypeGuard +else: # pragma: no cover + from typing_extensions import TypeGuard + +has_exceptiongroups = True +if sys.version_info < (3, 11): # pragma: no cover + try: + from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found] + except ImportError: + has_exceptiongroups = False + +T = typing.TypeVar("T") +AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] + + +@typing.overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... + + +@typing.overload +def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ... + + +def is_async_callable(obj: typing.Any) -> typing.Any: + while isinstance(obj, functools.partial): + obj = obj.func + + return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__)) + + +T_co = typing.TypeVar("T_co", covariant=True) + + +class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ... + + +class SupportsAsyncClose(typing.Protocol): + async def close(self) -> None: ... # pragma: no cover + + +SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False) + + +class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]): + __slots__ = ("aw", "entered") + + def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None: + self.aw = aw + + def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]: + return self.aw.__await__() + + async def __aenter__(self) -> SupportsAsyncCloseType: + self.entered = await self.aw + return self.entered + + async def __aexit__(self, *args: typing.Any) -> None | bool: + await self.entered.close() + return None + + +@contextmanager +def collapse_excgroups() -> typing.Generator[None, None, None]: + try: + yield + except BaseException as exc: + if has_exceptiongroups: # pragma: no cover + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc + + +def get_route_path(scope: Scope) -> str: + path: str = scope["path"] + root_path = scope.get("root_path", "") + if not root_path: + return path + + if not path.startswith(root_path): + return path + + if path == root_path: + return "" + + if path[len(root_path)] == "/": + return path[len(root_path) :] + + return path diff --git a/.venv/lib/python3.11/site-packages/starlette/authentication.py b/.venv/lib/python3.11/site-packages/starlette/authentication.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd866412b5e32fc333866cfea8271f3a7116907 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/authentication.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import functools +import inspect +import sys +import typing +from urllib.parse import urlencode + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette._utils import is_async_callable +from starlette.exceptions import HTTPException +from starlette.requests import HTTPConnection, Request +from starlette.responses import RedirectResponse +from starlette.websockets import WebSocket + +_P = ParamSpec("_P") + + +def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: + for scope in scopes: + if scope not in conn.auth.scopes: + return False + return True + + +def requires( + scopes: str | typing.Sequence[str], + status_code: int = 403, + redirect: str | None = None, +) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]: + scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) + + def decorator( + func: typing.Callable[_P, typing.Any], + ) -> typing.Callable[_P, typing.Any]: + sig = inspect.signature(func) + for idx, parameter in enumerate(sig.parameters.values()): + if parameter.name == "request" or parameter.name == "websocket": + type_ = parameter.name + break + else: + raise Exception(f'No "request" or "websocket" argument on function "{func}"') + + if type_ == "websocket": + # Handle websocket functions. (Always async) + @functools.wraps(func) + async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None) + assert isinstance(websocket, WebSocket) + + if not has_required_scope(websocket, scopes_list): + await websocket.close() + else: + await func(*args, **kwargs) + + return websocket_wrapper + + elif is_async_callable(func): + # Handle async request/response functions. + @functools.wraps(func) + async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + request = kwargs.get("request", args[idx] if idx < len(args) else None) + assert isinstance(request, Request) + + if not has_required_scope(request, scopes_list): + if redirect is not None: + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" + return RedirectResponse(url=next_url, status_code=303) + raise HTTPException(status_code=status_code) + return await func(*args, **kwargs) + + return async_wrapper + + else: + # Handle sync request/response functions. + @functools.wraps(func) + def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + request = kwargs.get("request", args[idx] if idx < len(args) else None) + assert isinstance(request, Request) + + if not has_required_scope(request, scopes_list): + if redirect is not None: + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" + return RedirectResponse(url=next_url, status_code=303) + raise HTTPException(status_code=status_code) + return func(*args, **kwargs) + + return sync_wrapper + + return decorator + + +class AuthenticationError(Exception): + pass + + +class AuthenticationBackend: + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None: + raise NotImplementedError() # pragma: no cover + + +class AuthCredentials: + def __init__(self, scopes: typing.Sequence[str] | None = None): + self.scopes = [] if scopes is None else list(scopes) + + +class BaseUser: + @property + def is_authenticated(self) -> bool: + raise NotImplementedError() # pragma: no cover + + @property + def display_name(self) -> str: + raise NotImplementedError() # pragma: no cover + + @property + def identity(self) -> str: + raise NotImplementedError() # pragma: no cover + + +class SimpleUser(BaseUser): + def __init__(self, username: str) -> None: + self.username = username + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.username + + +class UnauthenticatedUser(BaseUser): + @property + def is_authenticated(self) -> bool: + return False + + @property + def display_name(self) -> str: + return "" diff --git a/.venv/lib/python3.11/site-packages/starlette/background.py b/.venv/lib/python3.11/site-packages/starlette/background.py new file mode 100644 index 0000000000000000000000000000000000000000..0430fc08bb6b256767b8511220e89ae9373fa53f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/background.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import sys +import typing + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool + +P = ParamSpec("P") + + +class BackgroundTask: + def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + self.is_async = is_async_callable(func) + + async def __call__(self) -> None: + if self.is_async: + await self.func(*self.args, **self.kwargs) + else: + await run_in_threadpool(self.func, *self.args, **self.kwargs) + + +class BackgroundTasks(BackgroundTask): + def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None): + self.tasks = list(tasks) if tasks else [] + + def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: + task = BackgroundTask(func, *args, **kwargs) + self.tasks.append(task) + + async def __call__(self) -> None: + for task in self.tasks: + await task() diff --git a/.venv/lib/python3.11/site-packages/starlette/endpoints.py b/.venv/lib/python3.11/site-packages/starlette/endpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..107690266e16e85add5b7960e1915d153ceb8ccc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/endpoints.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import json +import typing + +from starlette import status +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import Message, Receive, Scope, Send +from starlette.websockets import WebSocket + + +class HTTPEndpoint: + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + self.scope = scope + self.receive = receive + self.send = send + self._allowed_methods = [ + method + for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS") + if getattr(self, method.lower(), None) is not None + ] + + def __await__(self) -> typing.Generator[typing.Any, None, None]: + return self.dispatch().__await__() + + async def dispatch(self) -> None: + request = Request(self.scope, receive=self.receive) + handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower() + + handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed) + is_async = is_async_callable(handler) + if is_async: + response = await handler(request) + else: + response = await run_in_threadpool(handler, request) + await response(self.scope, self.receive, self.send) + + async def method_not_allowed(self, request: Request) -> Response: + # If we're running inside a starlette application then raise an + # exception, so that the configurable exception handler can deal with + # returning the response. For plain ASGI apps, just return the response. + headers = {"Allow": ", ".join(self._allowed_methods)} + if "app" in self.scope: + raise HTTPException(status_code=405, headers=headers) + return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) + + +class WebSocketEndpoint: + encoding: str | None = None # May be "text", "bytes", or "json". + + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "websocket" + self.scope = scope + self.receive = receive + self.send = send + + def __await__(self) -> typing.Generator[typing.Any, None, None]: + return self.dispatch().__await__() + + async def dispatch(self) -> None: + websocket = WebSocket(self.scope, receive=self.receive, send=self.send) + await self.on_connect(websocket) + + close_code = status.WS_1000_NORMAL_CLOSURE + + try: + while True: + message = await websocket.receive() + if message["type"] == "websocket.receive": + data = await self.decode(websocket, message) + await self.on_receive(websocket, data) + elif message["type"] == "websocket.disconnect": # pragma: no branch + close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE) + break + except Exception as exc: + close_code = status.WS_1011_INTERNAL_ERROR + raise exc + finally: + await self.on_disconnect(websocket, close_code) + + async def decode(self, websocket: WebSocket, message: Message) -> typing.Any: + if self.encoding == "text": + if "text" not in message: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Expected text websocket messages, but got bytes") + return message["text"] + + elif self.encoding == "bytes": + if "bytes" not in message: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Expected bytes websocket messages, but got text") + return message["bytes"] + + elif self.encoding == "json": + if message.get("text") is not None: + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + + try: + return json.loads(text) + except json.decoder.JSONDecodeError: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Malformed JSON data received.") + + assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}" + return message["text"] if message.get("text") else message["bytes"] + + async def on_connect(self, websocket: WebSocket) -> None: + """Override to handle an incoming websocket connection""" + await websocket.accept() + + async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None: + """Override to handle an incoming websocket message""" + + async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: + """Override to handle a disconnecting websocket""" diff --git a/.venv/lib/python3.11/site-packages/starlette/formparsers.py b/.venv/lib/python3.11/site-packages/starlette/formparsers.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff1523b39468dc9ff84bb0d1a58244674eb61bd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/formparsers.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass, field +from enum import Enum +from tempfile import SpooledTemporaryFile +from urllib.parse import unquote_plus + +from starlette.datastructures import FormData, Headers, UploadFile + +if typing.TYPE_CHECKING: + import python_multipart as multipart + from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header +else: + try: + try: + import python_multipart as multipart + from python_multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + import multipart + from multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + multipart = None + parse_options_header = None + + +class FormMessage(Enum): + FIELD_START = 1 + FIELD_NAME = 2 + FIELD_DATA = 3 + FIELD_END = 4 + END = 5 + + +@dataclass +class MultipartPart: + content_disposition: bytes | None = None + field_name: str = "" + data: bytearray = field(default_factory=bytearray) + file: UploadFile | None = None + item_headers: list[tuple[bytes, bytes]] = field(default_factory=list) + + +def _user_safe_decode(src: bytes | bytearray, codec: str) -> str: + try: + return src.decode(codec) + except (UnicodeDecodeError, LookupError): + return src.decode("latin-1") + + +class MultiPartException(Exception): + def __init__(self, message: str) -> None: + self.message = message + + +class FormParser: + def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None: + assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." + self.headers = headers + self.stream = stream + self.messages: list[tuple[FormMessage, bytes]] = [] + + def on_field_start(self) -> None: + message = (FormMessage.FIELD_START, b"") + self.messages.append(message) + + def on_field_name(self, data: bytes, start: int, end: int) -> None: + message = (FormMessage.FIELD_NAME, data[start:end]) + self.messages.append(message) + + def on_field_data(self, data: bytes, start: int, end: int) -> None: + message = (FormMessage.FIELD_DATA, data[start:end]) + self.messages.append(message) + + def on_field_end(self) -> None: + message = (FormMessage.FIELD_END, b"") + self.messages.append(message) + + def on_end(self) -> None: + message = (FormMessage.END, b"") + self.messages.append(message) + + async def parse(self) -> FormData: + # Callbacks dictionary. + callbacks: QuerystringCallbacks = { + "on_field_start": self.on_field_start, + "on_field_name": self.on_field_name, + "on_field_data": self.on_field_data, + "on_field_end": self.on_field_end, + "on_end": self.on_end, + } + + # Create the parser. + parser = multipart.QuerystringParser(callbacks) + field_name = b"" + field_value = b"" + + items: list[tuple[str, str | UploadFile]] = [] + + # Feed the parser with data from the request. + async for chunk in self.stream: + if chunk: + parser.write(chunk) + else: + parser.finalize() + messages = list(self.messages) + self.messages.clear() + for message_type, message_bytes in messages: + if message_type == FormMessage.FIELD_START: + field_name = b"" + field_value = b"" + elif message_type == FormMessage.FIELD_NAME: + field_name += message_bytes + elif message_type == FormMessage.FIELD_DATA: + field_value += message_bytes + elif message_type == FormMessage.FIELD_END: + name = unquote_plus(field_name.decode("latin-1")) + value = unquote_plus(field_value.decode("latin-1")) + items.append((name, value)) + + return FormData(items) + + +class MultiPartParser: + max_file_size = 1024 * 1024 # 1MB + + def __init__( + self, + headers: Headers, + stream: typing.AsyncGenerator[bytes, None], + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_part_size: int = 1024 * 1024, # 1MB + ) -> None: + assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." + self.headers = headers + self.stream = stream + self.max_files = max_files + self.max_fields = max_fields + self.items: list[tuple[str, str | UploadFile]] = [] + self._current_files = 0 + self._current_fields = 0 + self._current_partial_header_name: bytes = b"" + self._current_partial_header_value: bytes = b"" + self._current_part = MultipartPart() + self._charset = "" + self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = [] + self._file_parts_to_finish: list[MultipartPart] = [] + self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] + self.max_part_size = max_part_size + + def on_part_begin(self) -> None: + self._current_part = MultipartPart() + + def on_part_data(self, data: bytes, start: int, end: int) -> None: + message_bytes = data[start:end] + if self._current_part.file is None: + if len(self._current_part.data) + len(message_bytes) > self.max_part_size: + raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.") + self._current_part.data.extend(message_bytes) + else: + self._file_parts_to_write.append((self._current_part, message_bytes)) + + def on_part_end(self) -> None: + if self._current_part.file is None: + self.items.append( + ( + self._current_part.field_name, + _user_safe_decode(self._current_part.data, self._charset), + ) + ) + else: + self._file_parts_to_finish.append(self._current_part) + # The file can be added to the items right now even though it's not + # finished yet, because it will be finished in the `parse()` method, before + # self.items is used in the return value. + self.items.append((self._current_part.field_name, self._current_part.file)) + + def on_header_field(self, data: bytes, start: int, end: int) -> None: + self._current_partial_header_name += data[start:end] + + def on_header_value(self, data: bytes, start: int, end: int) -> None: + self._current_partial_header_value += data[start:end] + + def on_header_end(self) -> None: + field = self._current_partial_header_name.lower() + if field == b"content-disposition": + self._current_part.content_disposition = self._current_partial_header_value + self._current_part.item_headers.append((field, self._current_partial_header_value)) + self._current_partial_header_name = b"" + self._current_partial_header_value = b"" + + def on_headers_finished(self) -> None: + disposition, options = parse_options_header(self._current_part.content_disposition) + try: + self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset) + except KeyError: + raise MultiPartException('The Content-Disposition header field "name" must be provided.') + if b"filename" in options: + self._current_files += 1 + if self._current_files > self.max_files: + raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") + filename = _user_safe_decode(options[b"filename"], self._charset) + tempfile = SpooledTemporaryFile(max_size=self.max_file_size) + self._files_to_close_on_error.append(tempfile) + self._current_part.file = UploadFile( + file=tempfile, # type: ignore[arg-type] + size=0, + filename=filename, + headers=Headers(raw=self._current_part.item_headers), + ) + else: + self._current_fields += 1 + if self._current_fields > self.max_fields: + raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.") + self._current_part.file = None + + def on_end(self) -> None: + pass + + async def parse(self) -> FormData: + # Parse the Content-Type header to get the multipart boundary. + _, params = parse_options_header(self.headers["Content-Type"]) + charset = params.get(b"charset", "utf-8") + if isinstance(charset, bytes): + charset = charset.decode("latin-1") + self._charset = charset + try: + boundary = params[b"boundary"] + except KeyError: + raise MultiPartException("Missing boundary in multipart.") + + # Callbacks dictionary. + callbacks: MultipartCallbacks = { + "on_part_begin": self.on_part_begin, + "on_part_data": self.on_part_data, + "on_part_end": self.on_part_end, + "on_header_field": self.on_header_field, + "on_header_value": self.on_header_value, + "on_header_end": self.on_header_end, + "on_headers_finished": self.on_headers_finished, + "on_end": self.on_end, + } + + # Create the parser. + parser = multipart.MultipartParser(boundary, callbacks) + try: + # Feed the parser with data from the request. + async for chunk in self.stream: + parser.write(chunk) + # Write file data, it needs to use await with the UploadFile methods + # that call the corresponding file methods *in a threadpool*, + # otherwise, if they were called directly in the callback methods above + # (regular, non-async functions), that would block the event loop in + # the main thread. + for part, data in self._file_parts_to_write: + assert part.file # for type checkers + await part.file.write(data) + for part in self._file_parts_to_finish: + assert part.file # for type checkers + await part.file.seek(0) + self._file_parts_to_write.clear() + self._file_parts_to_finish.clear() + except MultiPartException as exc: + # Close all the files if there was an error. + for file in self._files_to_close_on_error: + file.close() + raise exc + + parser.finalize() + return FormData(self.items) diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__init__.py b/.venv/lib/python3.11/site-packages/starlette/middleware/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b99538a272d985e5288faff04d2e7971b938d613 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/__init__.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import sys +from collections.abc import Iterator +from typing import Any, Protocol + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette.types import ASGIApp + +P = ParamSpec("P") + + +class _MiddlewareFactory(Protocol[P]): + def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover + + +class Middleware: + def __init__( + self, + cls: _MiddlewareFactory[P], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + self.cls = cls + self.args = args + self.kwargs = kwargs + + def __iter__(self) -> Iterator[Any]: + as_tuple = (self.cls, self.args, self.kwargs) + return iter(as_tuple) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + args_strings = [f"{value!r}" for value in self.args] + option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()] + name = getattr(self.cls, "__name__", "") + args_repr = ", ".join([name] + args_strings + option_strings) + return f"{class_name}({args_repr})" diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194bbc2f04210bff85417f075704f443826347c0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/authentication.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/authentication.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b90b64bb92a7245e856f38c2d62f532b3558fde Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/authentication.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3fdff81288b9e8fecfd45dd6a5979c0fd1de35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/cors.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/cors.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b6afed59225c29256b82d9a3c850101dfb227c4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/cors.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/errors.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/errors.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89ec8881427645ebaebdb3d843d162497cceb562 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5add17efee2d807b1d4da3949b5cb8ce95826788 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/gzip.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/gzip.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38ac32585fe00c5da52148d7e8054e429279dd3c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/gzip.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/httpsredirect.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/httpsredirect.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa41304209c20e4458608d98cbcfec6e9d354fe3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/httpsredirect.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/sessions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/sessions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ff1a4a1ed6f7feade16b5b066b98fef84fb6aad Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/sessions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/trustedhost.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/trustedhost.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7c6be0b8b026269e1e5ce9ec44f57b151c80983 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/trustedhost.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/wsgi.cpython-311.pyc b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/wsgi.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c80d46ddc1493b62ac7478f434aba8d1e103e5b3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/starlette/middleware/__pycache__/wsgi.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/authentication.py b/.venv/lib/python3.11/site-packages/starlette/middleware/authentication.py new file mode 100644 index 0000000000000000000000000000000000000000..8555ee0780e98b052eb463d55a1c18e39b257762 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/authentication.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import typing + +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + AuthenticationError, + UnauthenticatedUser, +) +from starlette.requests import HTTPConnection +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + + +class AuthenticationMiddleware: + def __init__( + self, + app: ASGIApp, + backend: AuthenticationBackend, + on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None, + ) -> None: + self.app = app + self.backend = backend + self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = ( + on_error if on_error is not None else self.default_on_error + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ["http", "websocket"]: + await self.app(scope, receive, send) + return + + conn = HTTPConnection(scope) + try: + auth_result = await self.backend.authenticate(conn) + except AuthenticationError as exc: + response = self.on_error(conn, exc) + if scope["type"] == "websocket": + await send({"type": "websocket.close", "code": 1000}) + else: + await response(scope, receive, send) + return + + if auth_result is None: + auth_result = AuthCredentials(), UnauthenticatedUser() + scope["auth"], scope["user"] = auth_result + await self.app(scope, receive, send) + + @staticmethod + def default_on_error(conn: HTTPConnection, exc: Exception) -> Response: + return PlainTextResponse(str(exc), status_code=400) diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/base.py b/.venv/lib/python3.11/site-packages/starlette/middleware/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f146984b3428c5217a9e62aa001ff87f3ae92d2e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/base.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import typing + +import anyio + +from starlette._utils import collapse_excgroups +from starlette.requests import ClientDisconnect, Request +from starlette.responses import AsyncContentStream, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] +DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] +T = typing.TypeVar("T") + + +class _CachedRequest(Request): + """ + If the user calls Request.body() from their dispatch function + we cache the entire request body in memory and pass that to downstream middlewares, + but if they call Request.stream() then all we do is send an + empty body so that downstream things don't hang forever. + """ + + def __init__(self, scope: Scope, receive: Receive): + super().__init__(scope, receive) + self._wrapped_rcv_disconnected = False + self._wrapped_rcv_consumed = False + self._wrapped_rc_stream = self.stream() + + async def wrapped_receive(self) -> Message: + # wrapped_rcv state 1: disconnected + if self._wrapped_rcv_disconnected: + # we've already sent a disconnect to the downstream app + # we don't need to wait to get another one + # (although most ASGI servers will just keep sending it) + return {"type": "http.disconnect"} + # wrapped_rcv state 1: consumed but not yet disconnected + if self._wrapped_rcv_consumed: + # since the downstream app has consumed us all that is left + # is to send it a disconnect + if self._is_disconnected: + # the middleware has already seen the disconnect + # since we know the client is disconnected no need to wait + # for the message + self._wrapped_rcv_disconnected = True + return {"type": "http.disconnect"} + # we don't know yet if the client is disconnected or not + # so we'll wait until we get that message + msg = await self.receive() + if msg["type"] != "http.disconnect": # pragma: no cover + # at this point a disconnect is all that we should be receiving + # if we get something else, things went wrong somewhere + raise RuntimeError(f"Unexpected message received: {msg['type']}") + self._wrapped_rcv_disconnected = True + return msg + + # wrapped_rcv state 3: not yet consumed + if getattr(self, "_body", None) is not None: + # body() was called, we return it even if the client disconnected + self._wrapped_rcv_consumed = True + return { + "type": "http.request", + "body": self._body, + "more_body": False, + } + elif self._stream_consumed: + # stream() was called to completion + # return an empty body so that downstream apps don't hang + # waiting for a disconnect + self._wrapped_rcv_consumed = True + return { + "type": "http.request", + "body": b"", + "more_body": False, + } + else: + # body() was never called and stream() wasn't consumed + try: + stream = self.stream() + chunk = await stream.__anext__() + self._wrapped_rcv_consumed = self._stream_consumed + return { + "type": "http.request", + "body": chunk, + "more_body": not self._stream_consumed, + } + except ClientDisconnect: + self._wrapped_rcv_disconnected = True + return {"type": "http.disconnect"} + + +class BaseHTTPMiddleware: + def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None: + self.app = app + self.dispatch_func = self.dispatch if dispatch is None else dispatch + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = _CachedRequest(scope, receive) + wrapped_receive = request.wrapped_receive + response_sent = anyio.Event() + + async def call_next(request: Request) -> Response: + app_exc: Exception | None = None + + async def receive_or_disconnect() -> Message: + if response_sent.is_set(): + return {"type": "http.disconnect"} + + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: + result = await func() + task_group.cancel_scope.cancel() + return result + + task_group.start_soon(wrap, response_sent.wait) + message = await wrap(wrapped_receive) + + if response_sent.is_set(): + return {"type": "http.disconnect"} + + return message + + async def send_no_error(message: Message) -> None: + try: + await send_stream.send(message) + except anyio.BrokenResourceError: + # recv_stream has been closed, i.e. response_sent has been set. + return + + async def coro() -> None: + nonlocal app_exc + + with send_stream: + try: + await self.app(scope, receive_or_disconnect, send_no_error) + except Exception as exc: + app_exc = exc + + task_group.start_soon(coro) + + try: + message = await recv_stream.receive() + info = message.get("info", None) + if message["type"] == "http.response.debug" and info is not None: + message = await recv_stream.receive() + except anyio.EndOfStream: + if app_exc is not None: + raise app_exc + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async for message in recv_stream: + assert message["type"] == "http.response.body" + body = message.get("body", b"") + if body: + yield body + if not message.get("more_body", False): + break + + if app_exc is not None: + raise app_exc + + response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info) + response.raw_headers = message["headers"] + return response + + streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream() + send_stream, recv_stream = streams + with recv_stream, send_stream, collapse_excgroups(): + async with anyio.create_task_group() as task_group: + response = await self.dispatch_func(request, call_next) + await response(scope, wrapped_receive, send) + response_sent.set() + recv_stream.close() + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + raise NotImplementedError() # pragma: no cover + + +class _StreamingResponse(Response): + def __init__( + self, + content: AsyncContentStream, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + info: typing.Mapping[str, typing.Any] | None = None, + ) -> None: + self.info = info + self.body_iterator = content + self.status_code = status_code + self.media_type = media_type + self.init_headers(headers) + self.background = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.info is not None: + await send({"type": "http.response.debug", "info": self.info}) + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + async for chunk in self.body_iterator: + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + if self.background: + await self.background() diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/cors.py b/.venv/lib/python3.11/site-packages/starlette/middleware/cors.py new file mode 100644 index 0000000000000000000000000000000000000000..61502691abdcde4ce790cb16c4d28002a6241311 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/cors.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import functools +import re +import typing + +from starlette.datastructures import Headers, MutableHeaders +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") +SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} + + +class CORSMiddleware: + def __init__( + self, + app: ASGIApp, + allow_origins: typing.Sequence[str] = (), + allow_methods: typing.Sequence[str] = ("GET",), + allow_headers: typing.Sequence[str] = (), + allow_credentials: bool = False, + allow_origin_regex: str | None = None, + expose_headers: typing.Sequence[str] = (), + max_age: int = 600, + ) -> None: + if "*" in allow_methods: + allow_methods = ALL_METHODS + + compiled_allow_origin_regex = None + if allow_origin_regex is not None: + compiled_allow_origin_regex = re.compile(allow_origin_regex) + + allow_all_origins = "*" in allow_origins + allow_all_headers = "*" in allow_headers + preflight_explicit_allow_origin = not allow_all_origins or allow_credentials + + simple_headers = {} + if allow_all_origins: + simple_headers["Access-Control-Allow-Origin"] = "*" + if allow_credentials: + simple_headers["Access-Control-Allow-Credentials"] = "true" + if expose_headers: + simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) + + preflight_headers = {} + if preflight_explicit_allow_origin: + # The origin value will be set in preflight_response() if it is allowed. + preflight_headers["Vary"] = "Origin" + else: + preflight_headers["Access-Control-Allow-Origin"] = "*" + preflight_headers.update( + { + "Access-Control-Allow-Methods": ", ".join(allow_methods), + "Access-Control-Max-Age": str(max_age), + } + ) + allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) + if allow_headers and not allow_all_headers: + preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) + if allow_credentials: + preflight_headers["Access-Control-Allow-Credentials"] = "true" + + self.app = app + self.allow_origins = allow_origins + self.allow_methods = allow_methods + self.allow_headers = [h.lower() for h in allow_headers] + self.allow_all_origins = allow_all_origins + self.allow_all_headers = allow_all_headers + self.preflight_explicit_allow_origin = preflight_explicit_allow_origin + self.allow_origin_regex = compiled_allow_origin_regex + self.simple_headers = simple_headers + self.preflight_headers = preflight_headers + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": # pragma: no cover + await self.app(scope, receive, send) + return + + method = scope["method"] + headers = Headers(scope=scope) + origin = headers.get("origin") + + if origin is None: + await self.app(scope, receive, send) + return + + if method == "OPTIONS" and "access-control-request-method" in headers: + response = self.preflight_response(request_headers=headers) + await response(scope, receive, send) + return + + await self.simple_response(scope, receive, send, request_headers=headers) + + def is_allowed_origin(self, origin: str) -> bool: + if self.allow_all_origins: + return True + + if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin): + return True + + return origin in self.allow_origins + + def preflight_response(self, request_headers: Headers) -> Response: + requested_origin = request_headers["origin"] + requested_method = request_headers["access-control-request-method"] + requested_headers = request_headers.get("access-control-request-headers") + + headers = dict(self.preflight_headers) + failures = [] + + if self.is_allowed_origin(origin=requested_origin): + if self.preflight_explicit_allow_origin: + # The "else" case is already accounted for in self.preflight_headers + # and the value would be "*". + headers["Access-Control-Allow-Origin"] = requested_origin + else: + failures.append("origin") + + if requested_method not in self.allow_methods: + failures.append("method") + + # If we allow all headers, then we have to mirror back any requested + # headers in the response. + if self.allow_all_headers and requested_headers is not None: + headers["Access-Control-Allow-Headers"] = requested_headers + elif requested_headers is not None: + for header in [h.lower() for h in requested_headers.split(",")]: + if header.strip() not in self.allow_headers: + failures.append("headers") + break + + # We don't strictly need to use 400 responses here, since its up to + # the browser to enforce the CORS policy, but its more informative + # if we do. + if failures: + failure_text = "Disallowed CORS " + ", ".join(failures) + return PlainTextResponse(failure_text, status_code=400, headers=headers) + + return PlainTextResponse("OK", status_code=200, headers=headers) + + async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None: + send = functools.partial(self.send, send=send, request_headers=request_headers) + await self.app(scope, receive, send) + + async def send(self, message: Message, send: Send, request_headers: Headers) -> None: + if message["type"] != "http.response.start": + await send(message) + return + + message.setdefault("headers", []) + headers = MutableHeaders(scope=message) + headers.update(self.simple_headers) + origin = request_headers["Origin"] + has_cookie = "cookie" in request_headers + + # If request includes any cookie headers, then we must respond + # with the specific origin instead of '*'. + if self.allow_all_origins and has_cookie: + self.allow_explicit_origin(headers, origin) + + # If we only allow specific origins, then we have to mirror back + # the Origin header in the response. + elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): + self.allow_explicit_origin(headers, origin) + + await send(message) + + @staticmethod + def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: + headers["Access-Control-Allow-Origin"] = origin + headers.add_vary_header("Origin") diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/errors.py b/.venv/lib/python3.11/site-packages/starlette/middleware/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..76ad776be2272b6ee6ad1e5cd948762465bf0dcc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/errors.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import html +import inspect +import sys +import traceback +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.requests import Request +from starlette.responses import HTMLResponse, PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +STYLES = """ +p { + color: #211c1c; +} +.traceback-container { + border: 1px solid #038BB8; +} +.traceback-title { + background-color: #038BB8; + color: lemonchiffon; + padding: 12px; + font-size: 20px; + margin-top: 0px; +} +.frame-line { + padding-left: 10px; + font-family: monospace; +} +.frame-filename { + font-family: monospace; +} +.center-line { + background-color: #038BB8; + color: #f9f6e1; + padding: 5px 0px 5px 5px; +} +.lineno { + margin-right: 5px; +} +.frame-title { + font-weight: unset; + padding: 10px 10px 10px 10px; + background-color: #E4F4FD; + margin-right: 10px; + color: #191f21; + font-size: 17px; + border: 1px solid #c7dce8; +} +.collapse-btn { + float: right; + padding: 0px 5px 1px 5px; + border: solid 1px #96aebb; + cursor: pointer; +} +.collapsed { + display: none; +} +.source-code { + font-family: courier; + font-size: small; + padding-bottom: 10px; +} +""" + +JS = """ + +""" + +TEMPLATE = """ + + + + Starlette Debugger + + +

500 Server Error

+

{error}

+
+

Traceback

+
{exc_html}
+
+ {js} + + +""" + +FRAME_TEMPLATE = """ +
+

File {frame_filename}, + line {frame_lineno}, + in {frame_name} + {collapse_button} +

+
{code_context}
+
+""" # noqa: E501 + +LINE = """ +

+{lineno}. {line}

+""" + +CENTER_LINE = """ +

+{lineno}. {line}

+""" + + +class ServerErrorMiddleware: + """ + Handles returning 500 responses when a server error occurs. + + If 'debug' is set, then traceback responses will be returned, + otherwise the designated 'handler' will be called. + + This middleware class should generally be used to wrap *everything* + else up, so that unhandled exceptions anywhere in the stack + always result in an appropriate 500 response. + """ + + def __init__( + self, + app: ASGIApp, + handler: typing.Callable[[Request, Exception], typing.Any] | None = None, + debug: bool = False, + ) -> None: + self.app = app + self.handler = handler + self.debug = debug + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + response_started = False + + async def _send(message: Message) -> None: + nonlocal response_started, send + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await self.app(scope, receive, _send) + except Exception as exc: + request = Request(scope) + if self.debug: + # In debug mode, return traceback responses. + response = self.debug_response(request, exc) + elif self.handler is None: + # Use our default 500 error handler. + response = self.error_response(request, exc) + else: + # Use an installed 500 error handler. + if is_async_callable(self.handler): + response = await self.handler(request, exc) + else: + response = await run_in_threadpool(self.handler, request, exc) + + if not response_started: + await response(scope, receive, send) + + # We always continue to raise the exception. + # This allows servers to log the error, or allows test clients + # to optionally raise the error within the test case. + raise exc + + def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str: + values = { + # HTML escape - line could contain < or > + "line": html.escape(line).replace(" ", " "), + "lineno": (frame_lineno - frame_index) + index, + } + + if index != frame_index: + return LINE.format(**values) + return CENTER_LINE.format(**values) + + def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str: + code_context = "".join( + self.format_line( + index, + line, + frame.lineno, + frame.index, # type: ignore[arg-type] + ) + for index, line in enumerate(frame.code_context or []) + ) + + values = { + # HTML escape - filename could contain < or >, especially if it's a virtual + # file e.g. in the REPL + "frame_filename": html.escape(frame.filename), + "frame_lineno": frame.lineno, + # HTML escape - if you try very hard it's possible to name a function with < + # or > + "frame_name": html.escape(frame.function), + "code_context": code_context, + "collapsed": "collapsed" if is_collapsed else "", + "collapse_button": "+" if is_collapsed else "‒", + } + return FRAME_TEMPLATE.format(**values) + + def generate_html(self, exc: Exception, limit: int = 7) -> str: + traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True) + + exc_html = "" + is_collapsed = False + exc_traceback = exc.__traceback__ + if exc_traceback is not None: + frames = inspect.getinnerframes(exc_traceback, limit) + for frame in reversed(frames): + exc_html += self.generate_frame_html(frame, is_collapsed) + is_collapsed = True + + if sys.version_info >= (3, 13): # pragma: no cover + exc_type_str = traceback_obj.exc_type_str + else: # pragma: no cover + exc_type_str = traceback_obj.exc_type.__name__ + + # escape error class and text + error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}" + + return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html) + + def generate_plain_text(self, exc: Exception) -> str: + return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + + def debug_response(self, request: Request, exc: Exception) -> Response: + accept = request.headers.get("accept", "") + + if "text/html" in accept: + content = self.generate_html(exc) + return HTMLResponse(content, status_code=500) + content = self.generate_plain_text(exc) + return PlainTextResponse(content, status_code=500) + + def error_response(self, request: Request, exc: Exception) -> Response: + return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/exceptions.py b/.venv/lib/python3.11/site-packages/starlette/middleware/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..981d2fcaef87c2d3fadb54c83afeb88ccdbaf561 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/exceptions.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import typing + +from starlette._exception_handler import ( + ExceptionHandlers, + StatusHandlers, + wrap_app_handling_exceptions, +) +from starlette.exceptions import HTTPException, WebSocketException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebSocket + + +class ExceptionMiddleware: + def __init__( + self, + app: ASGIApp, + handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None, + debug: bool = False, + ) -> None: + self.app = app + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers: StatusHandlers = {} + self._exception_handlers: ExceptionHandlers = { + HTTPException: self.http_exception, + WebSocketException: self.websocket_exception, + } + if handlers is not None: # pragma: no branch + for key, value in handlers.items(): + self.add_exception_handler(key, value) + + def add_exception_handler( + self, + exc_class_or_status_code: int | type[Exception], + handler: typing.Callable[[Request, Exception], Response], + ) -> None: + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler + else: + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) + + conn: Request | WebSocket + if scope["type"] == "http": + conn = Request(scope, receive, send) + else: + conn = WebSocket(scope, receive, send) + + await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send) + + def http_exception(self, request: Request, exc: Exception) -> Response: + assert isinstance(exc, HTTPException) + if exc.status_code in {204, 304}: + return Response(status_code=exc.status_code, headers=exc.headers) + return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers) + + async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None: + assert isinstance(exc, WebSocketException) + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/gzip.py b/.venv/lib/python3.11/site-packages/starlette/middleware/gzip.py new file mode 100644 index 0000000000000000000000000000000000000000..b677063da3702bc0d966b324edb9dc23dcff046e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/gzip.py @@ -0,0 +1,108 @@ +import gzip +import io +import typing + +from starlette.datastructures import Headers, MutableHeaders +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class GZipMiddleware: + def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None: + self.app = app + self.minimum_size = minimum_size + self.compresslevel = compresslevel + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": # pragma: no branch + headers = Headers(scope=scope) + if "gzip" in headers.get("Accept-Encoding", ""): + responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel) + await responder(scope, receive, send) + return + await self.app(scope, receive, send) + + +class GZipResponder: + def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None: + self.app = app + self.minimum_size = minimum_size + self.send: Send = unattached_send + self.initial_message: Message = {} + self.started = False + self.content_encoding_set = False + self.gzip_buffer = io.BytesIO() + self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + with self.gzip_buffer, self.gzip_file: + await self.app(scope, receive, self.send_with_gzip) + + async def send_with_gzip(self, message: Message) -> None: + message_type = message["type"] + if message_type == "http.response.start": + # Don't send the initial message until we've determined how to + # modify the outgoing headers correctly. + self.initial_message = message + headers = Headers(raw=self.initial_message["headers"]) + self.content_encoding_set = "content-encoding" in headers + elif message_type == "http.response.body" and self.content_encoding_set: + if not self.started: + self.started = True + await self.send(self.initial_message) + await self.send(message) + elif message_type == "http.response.body" and not self.started: + self.started = True + body = message.get("body", b"") + more_body = message.get("more_body", False) + if len(body) < self.minimum_size and not more_body: + # Don't apply GZip to small outgoing responses. + await self.send(self.initial_message) + await self.send(message) + elif not more_body: + # Standard GZip response. + self.gzip_file.write(body) + self.gzip_file.close() + body = self.gzip_buffer.getvalue() + + headers = MutableHeaders(raw=self.initial_message["headers"]) + headers["Content-Encoding"] = "gzip" + headers["Content-Length"] = str(len(body)) + headers.add_vary_header("Accept-Encoding") + message["body"] = body + + await self.send(self.initial_message) + await self.send(message) + else: + # Initial body in streaming GZip response. + headers = MutableHeaders(raw=self.initial_message["headers"]) + headers["Content-Encoding"] = "gzip" + headers.add_vary_header("Accept-Encoding") + del headers["Content-Length"] + + self.gzip_file.write(body) + message["body"] = self.gzip_buffer.getvalue() + self.gzip_buffer.seek(0) + self.gzip_buffer.truncate() + + await self.send(self.initial_message) + await self.send(message) + + elif message_type == "http.response.body": # pragma: no branch + # Remaining body in streaming GZip response. + body = message.get("body", b"") + more_body = message.get("more_body", False) + + self.gzip_file.write(body) + if not more_body: + self.gzip_file.close() + + message["body"] = self.gzip_buffer.getvalue() + self.gzip_buffer.seek(0) + self.gzip_buffer.truncate() + + await self.send(message) + + +async def unattached_send(message: Message) -> typing.NoReturn: + raise RuntimeError("send awaitable not set") # pragma: no cover diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/httpsredirect.py b/.venv/lib/python3.11/site-packages/starlette/middleware/httpsredirect.py new file mode 100644 index 0000000000000000000000000000000000000000..a8359067ff7afb80e979042077d5fa0fff119ddf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/httpsredirect.py @@ -0,0 +1,19 @@ +from starlette.datastructures import URL +from starlette.responses import RedirectResponse +from starlette.types import ASGIApp, Receive, Scope, Send + + +class HTTPSRedirectMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"): + url = URL(scope=scope) + redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme] + netloc = url.hostname if url.port in (80, 443) else url.netloc + url = url.replace(scheme=redirect_scheme, netloc=netloc) + response = RedirectResponse(url, status_code=307) + await response(scope, receive, send) + else: + await self.app(scope, receive, send) diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/sessions.py b/.venv/lib/python3.11/site-packages/starlette/middleware/sessions.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9fcd883b69fb960b067f316d49401aba425b4c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/sessions.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import json +import typing +from base64 import b64decode, b64encode + +import itsdangerous +from itsdangerous.exc import BadSignature + +from starlette.datastructures import MutableHeaders, Secret +from starlette.requests import HTTPConnection +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class SessionMiddleware: + def __init__( + self, + app: ASGIApp, + secret_key: str | Secret, + session_cookie: str = "session", + max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds + path: str = "/", + same_site: typing.Literal["lax", "strict", "none"] = "lax", + https_only: bool = False, + domain: str | None = None, + ) -> None: + self.app = app + self.signer = itsdangerous.TimestampSigner(str(secret_key)) + self.session_cookie = session_cookie + self.max_age = max_age + self.path = path + self.security_flags = "httponly; samesite=" + same_site + if https_only: # Secure flag can be used with HTTPS only + self.security_flags += "; secure" + if domain is not None: + self.security_flags += f"; domain={domain}" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): # pragma: no cover + await self.app(scope, receive, send) + return + + connection = HTTPConnection(scope) + initial_session_was_empty = True + + if self.session_cookie in connection.cookies: + data = connection.cookies[self.session_cookie].encode("utf-8") + try: + data = self.signer.unsign(data, max_age=self.max_age) + scope["session"] = json.loads(b64decode(data)) + initial_session_was_empty = False + except BadSignature: + scope["session"] = {} + else: + scope["session"] = {} + + async def send_wrapper(message: Message) -> None: + if message["type"] == "http.response.start": + if scope["session"]: + # We have session data to persist. + data = b64encode(json.dumps(scope["session"]).encode("utf-8")) + data = self.signer.sign(data) + headers = MutableHeaders(scope=message) + header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( + session_cookie=self.session_cookie, + data=data.decode("utf-8"), + path=self.path, + max_age=f"Max-Age={self.max_age}; " if self.max_age else "", + security_flags=self.security_flags, + ) + headers.append("Set-Cookie", header_value) + elif not initial_session_was_empty: + # The session has been cleared. + headers = MutableHeaders(scope=message) + header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( + session_cookie=self.session_cookie, + data="null", + path=self.path, + expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ", + security_flags=self.security_flags, + ) + headers.append("Set-Cookie", header_value) + await send(message) + + await self.app(scope, receive, send_wrapper) diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.11/site-packages/starlette/middleware/trustedhost.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1c999e25f40929aacd526ad8b737bbd6394e1f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/trustedhost.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing + +from starlette.datastructures import URL, Headers +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." + + +class TrustedHostMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_hosts: typing.Sequence[str] | None = None, + www_redirect: bool = True, + ) -> None: + if allowed_hosts is None: + allowed_hosts = ["*"] + + for pattern in allowed_hosts: + assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD + if pattern.startswith("*") and pattern != "*": + assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD + self.app = app + self.allowed_hosts = list(allowed_hosts) + self.allow_any = "*" in allowed_hosts + self.www_redirect = www_redirect + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.allow_any or scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + is_valid_host = False + found_www_redirect = False + for pattern in self.allowed_hosts: + if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])): + is_valid_host = True + break + elif "www." + host == pattern: + found_www_redirect = True + + if is_valid_host: + await self.app(scope, receive, send) + else: + response: Response + if found_www_redirect and self.www_redirect: + url = URL(scope=scope) + redirect_url = url.replace(netloc="www." + url.netloc) + response = RedirectResponse(url=str(redirect_url)) + else: + response = PlainTextResponse("Invalid host header", status_code=400) + await response(scope, receive, send) diff --git a/.venv/lib/python3.11/site-packages/starlette/middleware/wsgi.py b/.venv/lib/python3.11/site-packages/starlette/middleware/wsgi.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0a3fae6c176485e506b498200ac22c499c6c50 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/middleware/wsgi.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import io +import math +import sys +import typing +import warnings + +import anyio +from anyio.abc import ObjectReceiveStream, ObjectSendStream + +from starlette.types import Receive, Scope, Send + +warnings.warn( + "starlette.middleware.wsgi is deprecated and will be removed in a future release. " + "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.", + DeprecationWarning, +) + + +def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]: + """ + Builds a scope and request body into a WSGI environ object. + """ + + script_name = scope.get("root_path", "").encode("utf8").decode("latin1") + path_info = scope["path"].encode("utf8").decode("latin1") + if path_info.startswith(script_name): + path_info = path_info[len(script_name) :] + + environ = { + "REQUEST_METHOD": scope["method"], + "SCRIPT_NAME": script_name, + "PATH_INFO": path_info, + "QUERY_STRING": scope["query_string"].decode("ascii"), + "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", + "wsgi.version": (1, 0), + "wsgi.url_scheme": scope.get("scheme", "http"), + "wsgi.input": io.BytesIO(body), + "wsgi.errors": sys.stdout, + "wsgi.multithread": True, + "wsgi.multiprocess": True, + "wsgi.run_once": False, + } + + # Get server name and port - required in WSGI, not in ASGI + server = scope.get("server") or ("localhost", 80) + environ["SERVER_NAME"] = server[0] + environ["SERVER_PORT"] = server[1] + + # Get client IP address + if scope.get("client"): + environ["REMOTE_ADDR"] = scope["client"][0] + + # Go through headers and make them into environ entries + for name, value in scope.get("headers", []): + name = name.decode("latin1") + if name == "content-length": + corrected_name = "CONTENT_LENGTH" + elif name == "content-type": + corrected_name = "CONTENT_TYPE" + else: + corrected_name = f"HTTP_{name}".upper().replace("-", "_") + # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in + # case + value = value.decode("latin1") + if corrected_name in environ: + value = environ[corrected_name] + "," + value + environ[corrected_name] = value + return environ + + +class WSGIMiddleware: + def __init__(self, app: typing.Callable[..., typing.Any]) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + responder = WSGIResponder(self.app, scope) + await responder(receive, send) + + +class WSGIResponder: + stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + + def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None: + self.app = app + self.scope = scope + self.status = None + self.response_headers = None + self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf) + self.response_started = False + self.exc_info: typing.Any = None + + async def __call__(self, receive: Receive, send: Send) -> None: + body = b"" + more_body = True + while more_body: + message = await receive() + body += message.get("body", b"") + more_body = message.get("more_body", False) + environ = build_environ(self.scope, body) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self.sender, send) + async with self.stream_send: + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) + + async def sender(self, send: Send) -> None: + async with self.stream_receive: + async for message in self.stream_receive: + await send(message) + + def start_response( + self, + status: str, + response_headers: list[tuple[str, str]], + exc_info: typing.Any = None, + ) -> None: + self.exc_info = exc_info + if not self.response_started: # pragma: no branch + self.response_started = True + status_code_string, _ = status.split(" ", 1) + status_code = int(status_code_string) + headers = [ + (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) + for name, value in response_headers + ] + anyio.from_thread.run( + self.stream_send.send, + { + "type": "http.response.start", + "status": status_code, + "headers": headers, + }, + ) + + def wsgi( + self, + environ: dict[str, typing.Any], + start_response: typing.Callable[..., typing.Any], + ) -> None: + for chunk in self.app(environ, start_response): + anyio.from_thread.run( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, + ) + + anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""}) diff --git a/.venv/lib/python3.11/site-packages/starlette/routing.py b/.venv/lib/python3.11/site-packages/starlette/routing.py new file mode 100644 index 0000000000000000000000000000000000000000..74daeadb0beec0ca243bf8a52ee97b7b3aad8716 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/routing.py @@ -0,0 +1,875 @@ +from __future__ import annotations + +import contextlib +import functools +import inspect +import re +import traceback +import types +import typing +import warnings +from contextlib import asynccontextmanager +from enum import Enum + +from starlette._exception_handler import wrap_app_handling_exceptions +from starlette._utils import get_route_path, is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.convertors import CONVERTOR_TYPES, Convertor +from starlette.datastructures import URL, Headers, URLPath +from starlette.exceptions import HTTPException +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send +from starlette.websockets import WebSocket, WebSocketClose + + +class NoMatchFound(Exception): + """ + Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)` + if no matching route exists. + """ + + def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None: + params = ", ".join(list(path_params.keys())) + super().__init__(f'No route exists for name "{name}" and params "{params}".') + + +class Match(Enum): + NONE = 0 + PARTIAL = 1 + FULL = 2 + + +def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover + """ + Correctly determines if an object is a coroutine function, + including those wrapped in functools.partial objects. + """ + warnings.warn( + "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.", + DeprecationWarning, + ) + while isinstance(obj, functools.partial): + obj = obj.func + return inspect.iscoroutinefunction(obj) + + +def request_response( + func: typing.Callable[[Request], typing.Awaitable[Response] | Response], +) -> ASGIApp: + """ + Takes a function or coroutine `func(request) -> response`, + and returns an ASGI application. + """ + f: typing.Callable[[Request], typing.Awaitable[Response]] = ( + func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore + ) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope, receive, send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = await f(request) + await response(scope, receive, send) + + await wrap_app_handling_exceptions(app, request)(scope, receive, send) + + return app + + +def websocket_session( + func: typing.Callable[[WebSocket], typing.Awaitable[None]], +) -> ASGIApp: + """ + Takes a coroutine `func(session)`, and returns an ASGI application. + """ + # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + session = WebSocket(scope, receive=receive, send=send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await func(session) + + await wrap_app_handling_exceptions(app, session)(scope, receive, send) + + return app + + +def get_name(endpoint: typing.Callable[..., typing.Any]) -> str: + return getattr(endpoint, "__name__", endpoint.__class__.__name__) + + +def replace_params( + path: str, + param_convertors: dict[str, Convertor[typing.Any]], + path_params: dict[str, str], +) -> tuple[str, dict[str, str]]: + for key, value in list(path_params.items()): + if "{" + key + "}" in path: + convertor = param_convertors[key] + value = convertor.to_string(value) + path = path.replace("{" + key + "}", value) + path_params.pop(key) + return path, path_params + + +# Match parameters in URL paths, eg. '{param}', and '{param:int}' +PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") + + +def compile_path( + path: str, +) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]: + """ + Given a path string, like: "/{username:str}", + or a host string, like: "{subdomain}.mydomain.org", return a three-tuple + of (regex, format, {param_name:convertor}). + + regex: "/(?P[^/]+)" + format: "/{username}" + convertors: {"username": StringConvertor()} + """ + is_host = not path.startswith("/") + + path_regex = "^" + path_format = "" + duplicated_params = set() + + idx = 0 + param_convertors = {} + for match in PARAM_REGEX.finditer(path): + param_name, convertor_type = match.groups("str") + convertor_type = convertor_type.lstrip(":") + assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'" + convertor = CONVERTOR_TYPES[convertor_type] + + path_regex += re.escape(path[idx : match.start()]) + path_regex += f"(?P<{param_name}>{convertor.regex})" + + path_format += path[idx : match.start()] + path_format += "{%s}" % param_name + + if param_name in param_convertors: + duplicated_params.add(param_name) + + param_convertors[param_name] = convertor + + idx = match.end() + + if duplicated_params: + names = ", ".join(sorted(duplicated_params)) + ending = "s" if len(duplicated_params) > 1 else "" + raise ValueError(f"Duplicated param name{ending} {names} at path {path}") + + if is_host: + # Align with `Host.matches()` behavior, which ignores port. + hostname = path[idx:].split(":")[0] + path_regex += re.escape(hostname) + "$" + else: + path_regex += re.escape(path[idx:]) + "$" + + path_format += path[idx:] + + return re.compile(path_regex), path_format, param_convertors + + +class BaseRoute: + def matches(self, scope: Scope) -> tuple[Match, Scope]: + raise NotImplementedError() # pragma: no cover + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + raise NotImplementedError() # pragma: no cover + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + raise NotImplementedError() # pragma: no cover + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + A route may be used in isolation as a stand-alone ASGI app. + This is a somewhat contrived case, as they'll almost always be used + within a Router, but could be useful for some tooling and minimal apps. + """ + match, child_scope = self.matches(scope) + if match == Match.NONE: + if scope["type"] == "http": + response = PlainTextResponse("Not Found", status_code=404) + await response(scope, receive, send) + elif scope["type"] == "websocket": # pragma: no branch + websocket_close = WebSocketClose() + await websocket_close(scope, receive, send) + return + + scope.update(child_scope) + await self.handle(scope, receive, send) + + +class Route(BaseRoute): + def __init__( + self, + path: str, + endpoint: typing.Callable[..., typing.Any], + *, + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + assert path.startswith("/"), "Routed paths must start with '/'" + self.path = path + self.endpoint = endpoint + self.name = get_name(endpoint) if name is None else name + self.include_in_schema = include_in_schema + + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): + # Endpoint is function or method. Treat it as `func(request) -> response`. + self.app = request_response(endpoint) + if methods is None: + methods = ["GET"] + else: + # Endpoint is a class. Treat it as ASGI. + self.app = endpoint + + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(self.app, *args, **kwargs) + + if methods is None: + self.methods = None + else: + self.methods = {method.upper() for method in methods} + if "GET" in self.methods: + self.methods.add("HEAD") + + self.path_regex, self.path_format, self.param_convertors = compile_path(path) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + path_params: dict[str, typing.Any] + if scope["type"] == "http": + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"endpoint": self.endpoint, "path_params": path_params} + if self.methods and scope["method"] not in self.methods: + return Match.PARTIAL, child_scope + else: + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + seen_params = set(path_params.keys()) + expected_params = set(self.param_convertors.keys()) + + if name != self.name or seen_params != expected_params: + raise NoMatchFound(name, path_params) + + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + assert not remaining_params + return URLPath(path=path, protocol="http") + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.methods and scope["method"] not in self.methods: + headers = {"Allow": ", ".join(self.methods)} + if "app" in scope: + raise HTTPException(status_code=405, headers=headers) + else: + response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) + await response(scope, receive, send) + else: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, Route) + and self.path == other.path + and self.endpoint == other.endpoint + and self.methods == other.methods + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + methods = sorted(self.methods or []) + path, name = self.path, self.name + return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})" + + +class WebSocketRoute(BaseRoute): + def __init__( + self, + path: str, + endpoint: typing.Callable[..., typing.Any], + *, + name: str | None = None, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + assert path.startswith("/"), "Routed paths must start with '/'" + self.path = path + self.endpoint = endpoint + self.name = get_name(endpoint) if name is None else name + + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): + # Endpoint is function or method. Treat it as `func(websocket)`. + self.app = websocket_session(endpoint) + else: + # Endpoint is a class. Treat it as ASGI. + self.app = endpoint + + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(self.app, *args, **kwargs) + + self.path_regex, self.path_format, self.param_convertors = compile_path(path) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + path_params: dict[str, typing.Any] + if scope["type"] == "websocket": + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"endpoint": self.endpoint, "path_params": path_params} + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + seen_params = set(path_params.keys()) + expected_params = set(self.param_convertors.keys()) + + if name != self.name or seen_params != expected_params: + raise NoMatchFound(name, path_params) + + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + assert not remaining_params + return URLPath(path=path, protocol="websocket") + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})" + + +class Mount(BaseRoute): + def __init__( + self, + path: str, + app: ASGIApp | None = None, + routes: typing.Sequence[BaseRoute] | None = None, + name: str | None = None, + *, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + assert path == "" or path.startswith("/"), "Routed paths must start with '/'" + assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified" + self.path = path.rstrip("/") + if app is not None: + self._base_app: ASGIApp = app + else: + self._base_app = Router(routes=routes) + self.app = self._base_app + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(self.app, *args, **kwargs) + self.name = name + self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}") + + @property + def routes(self) -> list[BaseRoute]: + return getattr(self._base_app, "routes", []) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + path_params: dict[str, typing.Any] + if scope["type"] in ("http", "websocket"): # pragma: no branch + root_path = scope.get("root_path", "") + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + remaining_path = "/" + matched_params.pop("path") + matched_path = route_path[: -len(remaining_path)] + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = { + "path_params": path_params, + # app_root_path will only be set at the top level scope, + # initialized with the (optional) value of a root_path + # set above/before Starlette. And even though any + # mount will have its own child scope with its own respective + # root_path, the app_root_path will always be available in all + # the child scopes with the same top level value because it's + # set only once here with a default, any other child scope will + # just inherit that app_root_path default value stored in the + # scope. All this is needed to support Request.url_for(), as it + # uses the app_root_path to build the URL path. + "app_root_path": scope.get("app_root_path", root_path), + "root_path": root_path + matched_path, + "endpoint": self.app, + } + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + if self.name is not None and name == self.name and "path" in path_params: + # 'name' matches "". + path_params["path"] = path_params["path"].lstrip("/") + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + if not remaining_params: + return URLPath(path=path) + elif self.name is None or name.startswith(self.name + ":"): + if self.name is None: + # No mount name. + remaining_name = name + else: + # 'name' matches ":". + remaining_name = name[len(self.name) + 1 :] + path_kwarg = path_params.get("path") + path_params["path"] = "" + path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + if path_kwarg is not None: + remaining_params["path"] = path_kwarg + for route in self.routes or []: + try: + url = route.url_path_for(remaining_name, **remaining_params) + return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol) + except NoMatchFound: + pass + raise NoMatchFound(name, path_params) + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, Mount) and self.path == other.path and self.app == other.app + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + name = self.name or "" + return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})" + + +class Host(BaseRoute): + def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None: + assert not host.startswith("/"), "Host must not start with '/'" + self.host = host + self.app = app + self.name = name + self.host_regex, self.host_format, self.param_convertors = compile_path(host) + + @property + def routes(self) -> list[BaseRoute]: + return getattr(self.app, "routes", []) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + if scope["type"] in ("http", "websocket"): # pragma:no branch + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + match = self.host_regex.match(host) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"path_params": path_params, "endpoint": self.app} + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + if self.name is not None and name == self.name and "path" in path_params: + # 'name' matches "". + path = path_params.pop("path") + host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params) + if not remaining_params: + return URLPath(path=path, host=host) + elif self.name is None or name.startswith(self.name + ":"): + if self.name is None: + # No mount name. + remaining_name = name + else: + # 'name' matches ":". + remaining_name = name[len(self.name) + 1 :] + host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params) + for route in self.routes or []: + try: + url = route.url_path_for(remaining_name, **remaining_params) + return URLPath(path=str(url), protocol=url.protocol, host=host) + except NoMatchFound: + pass + raise NoMatchFound(name, path_params) + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, Host) and self.host == other.host and self.app == other.app + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + name = self.name or "" + return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})" + + +_T = typing.TypeVar("_T") + + +class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): + def __init__(self, cm: typing.ContextManager[_T]): + self._cm = cm + + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: + return self._cm.__exit__(exc_type, exc_value, traceback) + + +def _wrap_gen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]], +) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]: + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + + +class _DefaultLifespan: + def __init__(self, router: Router): + self._router = router + + async def __aenter__(self) -> None: + await self._router.startup() + + async def __aexit__(self, *exc_info: object) -> None: + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: + return self + + +class Router: + def __init__( + self, + routes: typing.Sequence[BaseRoute] | None = None, + redirect_slashes: bool = True, + default: ASGIApp | None = None, + on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + # the generic to Lifespan[AppType] is the type of the top level application + # which the router cannot know statically, so we use typing.Any + lifespan: Lifespan[typing.Any] | None = None, + *, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + self.routes = [] if routes is None else list(routes) + self.redirect_slashes = redirect_slashes + self.default = self.not_found if default is None else default + self.on_startup = [] if on_startup is None else list(on_startup) + self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) + + if on_startup or on_shutdown: + warnings.warn( + "The on_startup and on_shutdown parameters are deprecated, and they " + "will be removed on version 1.0. Use the lifespan parameter instead. " + "See more about it on https://www.starlette.io/lifespan/.", + DeprecationWarning, + ) + if lifespan: + warnings.warn( + "The `lifespan` parameter cannot be used with `on_startup` or " + "`on_shutdown`. Both `on_startup` and `on_shutdown` will be " + "ignored." + ) + + if lifespan is None: + self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self) + + elif inspect.isasyncgenfunction(lifespan): + warnings.warn( + "async generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = asynccontextmanager( + lifespan, + ) + elif inspect.isgeneratorfunction(lifespan): + warnings.warn( + "generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = _wrap_gen_lifespan_context( + lifespan, + ) + else: + self.lifespan_context = lifespan + + self.middleware_stack = self.app + if middleware: + for cls, args, kwargs in reversed(middleware): + self.middleware_stack = cls(self.middleware_stack, *args, **kwargs) + + async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "websocket": + websocket_close = WebSocketClose() + await websocket_close(scope, receive, send) + return + + # If we're running inside a starlette application then raise an + # exception, so that the configurable exception handler can deal with + # returning the response. For plain ASGI apps, just return the response. + if "app" in scope: + raise HTTPException(status_code=404) + else: + response = PlainTextResponse("Not Found", status_code=404) + await response(scope, receive, send) + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + for route in self.routes: + try: + return route.url_path_for(name, **path_params) + except NoMatchFound: + pass + raise NoMatchFound(name, path_params) + + async def startup(self) -> None: + """ + Run any `.on_startup` event handlers. + """ + for handler in self.on_startup: + if is_async_callable(handler): + await handler() + else: + handler() + + async def shutdown(self) -> None: + """ + Run any `.on_shutdown` event handlers. + """ + for handler in self.on_shutdown: + if is_async_callable(handler): + await handler() + else: + handler() + + async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + Handle ASGI lifespan messages, which allows us to manage application + startup and shutdown events. + """ + started = False + app: typing.Any = scope.get("app") + await receive() + try: + async with self.lifespan_context(app) as maybe_state: + if maybe_state is not None: + if "state" not in scope: + raise RuntimeError('The server does not support "state" in the lifespan scope.') + scope["state"].update(maybe_state) + await send({"type": "lifespan.startup.complete"}) + started = True + await receive() + except BaseException: + exc_text = traceback.format_exc() + if started: + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + else: + await send({"type": "lifespan.startup.failed", "message": exc_text}) + raise + else: + await send({"type": "lifespan.shutdown.complete"}) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + The main entry point to the Router class. + """ + await self.middleware_stack(scope, receive, send) + + async def app(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] in ("http", "websocket", "lifespan") + + if "router" not in scope: + scope["router"] = self + + if scope["type"] == "lifespan": + await self.lifespan(scope, receive, send) + return + + partial = None + + for route in self.routes: + # Determine if any route matches the incoming scope, + # and hand over to the matching route if found. + match, child_scope = route.matches(scope) + if match == Match.FULL: + scope.update(child_scope) + await route.handle(scope, receive, send) + return + elif match == Match.PARTIAL and partial is None: + partial = route + partial_scope = child_scope + + if partial is not None: + #  Handle partial matches. These are cases where an endpoint is + # able to handle the request, but is not a preferred option. + # We use this in particular to deal with "405 Method Not Allowed". + scope.update(partial_scope) + await partial.handle(scope, receive, send) + return + + route_path = get_route_path(scope) + if scope["type"] == "http" and self.redirect_slashes and route_path != "/": + redirect_scope = dict(scope) + if route_path.endswith("/"): + redirect_scope["path"] = redirect_scope["path"].rstrip("/") + else: + redirect_scope["path"] = redirect_scope["path"] + "/" + + for route in self.routes: + match, child_scope = route.matches(redirect_scope) + if match != Match.NONE: + redirect_url = URL(scope=redirect_scope) + response = RedirectResponse(url=str(redirect_url)) + await response(scope, receive, send) + return + + await self.default(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, Router) and self.routes == other.routes + + def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover + route = Mount(path, app=app, name=name) + self.routes.append(route) + + def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover + route = Host(host, app=app, name=name) + self.routes.append(route) + + def add_route( + self, + path: str, + endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response], + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> None: # pragma: no cover + route = Route( + path, + endpoint=endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + self.routes.append(route) + + def add_websocket_route( + self, + path: str, + endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]], + name: str | None = None, + ) -> None: # pragma: no cover + route = WebSocketRoute(path, endpoint=endpoint, name=name) + self.routes.append(route) + + def route( + self, + path: str, + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [Route(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `route` decorator is deprecated, and will be removed in version 1.0.0." + "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + return func + + return decorator + + def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [WebSocketRoute(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " + "https://www.starlette.io/routing/#websocket-routing for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_websocket_route(path, func, name=name) + return func + + return decorator + + def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover + assert event_type in ("startup", "shutdown") + + if event_type == "startup": + self.on_startup.append(func) + else: + self.on_shutdown.append(func) + + def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] + warnings.warn( + "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/lifespan/ for recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_event_handler(event_type, func) + return func + + return decorator diff --git a/.venv/lib/python3.11/site-packages/starlette/status.py b/.venv/lib/python3.11/site-packages/starlette/status.py new file mode 100644 index 0000000000000000000000000000000000000000..54c1fb7d0df128bfe16fb4cb03cdf23b8225af94 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/status.py @@ -0,0 +1,95 @@ +""" +HTTP codes +See HTTP Status Code Registry: +https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml + +And RFC 2324 - https://tools.ietf.org/html/rfc2324 +""" + +from __future__ import annotations + +HTTP_100_CONTINUE = 100 +HTTP_101_SWITCHING_PROTOCOLS = 101 +HTTP_102_PROCESSING = 102 +HTTP_103_EARLY_HINTS = 103 +HTTP_200_OK = 200 +HTTP_201_CREATED = 201 +HTTP_202_ACCEPTED = 202 +HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203 +HTTP_204_NO_CONTENT = 204 +HTTP_205_RESET_CONTENT = 205 +HTTP_206_PARTIAL_CONTENT = 206 +HTTP_207_MULTI_STATUS = 207 +HTTP_208_ALREADY_REPORTED = 208 +HTTP_226_IM_USED = 226 +HTTP_300_MULTIPLE_CHOICES = 300 +HTTP_301_MOVED_PERMANENTLY = 301 +HTTP_302_FOUND = 302 +HTTP_303_SEE_OTHER = 303 +HTTP_304_NOT_MODIFIED = 304 +HTTP_305_USE_PROXY = 305 +HTTP_306_RESERVED = 306 +HTTP_307_TEMPORARY_REDIRECT = 307 +HTTP_308_PERMANENT_REDIRECT = 308 +HTTP_400_BAD_REQUEST = 400 +HTTP_401_UNAUTHORIZED = 401 +HTTP_402_PAYMENT_REQUIRED = 402 +HTTP_403_FORBIDDEN = 403 +HTTP_404_NOT_FOUND = 404 +HTTP_405_METHOD_NOT_ALLOWED = 405 +HTTP_406_NOT_ACCEPTABLE = 406 +HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407 +HTTP_408_REQUEST_TIMEOUT = 408 +HTTP_409_CONFLICT = 409 +HTTP_410_GONE = 410 +HTTP_411_LENGTH_REQUIRED = 411 +HTTP_412_PRECONDITION_FAILED = 412 +HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413 +HTTP_414_REQUEST_URI_TOO_LONG = 414 +HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415 +HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416 +HTTP_417_EXPECTATION_FAILED = 417 +HTTP_418_IM_A_TEAPOT = 418 +HTTP_421_MISDIRECTED_REQUEST = 421 +HTTP_422_UNPROCESSABLE_ENTITY = 422 +HTTP_423_LOCKED = 423 +HTTP_424_FAILED_DEPENDENCY = 424 +HTTP_425_TOO_EARLY = 425 +HTTP_426_UPGRADE_REQUIRED = 426 +HTTP_428_PRECONDITION_REQUIRED = 428 +HTTP_429_TOO_MANY_REQUESTS = 429 +HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431 +HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451 +HTTP_500_INTERNAL_SERVER_ERROR = 500 +HTTP_501_NOT_IMPLEMENTED = 501 +HTTP_502_BAD_GATEWAY = 502 +HTTP_503_SERVICE_UNAVAILABLE = 503 +HTTP_504_GATEWAY_TIMEOUT = 504 +HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505 +HTTP_506_VARIANT_ALSO_NEGOTIATES = 506 +HTTP_507_INSUFFICIENT_STORAGE = 507 +HTTP_508_LOOP_DETECTED = 508 +HTTP_510_NOT_EXTENDED = 510 +HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 + + +""" +WebSocket codes +https://www.iana.org/assignments/websocket/websocket.xml#close-code-number +https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent +""" +WS_1000_NORMAL_CLOSURE = 1000 +WS_1001_GOING_AWAY = 1001 +WS_1002_PROTOCOL_ERROR = 1002 +WS_1003_UNSUPPORTED_DATA = 1003 +WS_1005_NO_STATUS_RCVD = 1005 +WS_1006_ABNORMAL_CLOSURE = 1006 +WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007 +WS_1008_POLICY_VIOLATION = 1008 +WS_1009_MESSAGE_TOO_BIG = 1009 +WS_1010_MANDATORY_EXT = 1010 +WS_1011_INTERNAL_ERROR = 1011 +WS_1012_SERVICE_RESTART = 1012 +WS_1013_TRY_AGAIN_LATER = 1013 +WS_1014_BAD_GATEWAY = 1014 +WS_1015_TLS_HANDSHAKE = 1015 diff --git a/.venv/lib/python3.11/site-packages/starlette/websockets.py b/.venv/lib/python3.11/site-packages/starlette/websockets.py new file mode 100644 index 0000000000000000000000000000000000000000..6b46f4eaef63a97d9263f0add4da1ac657b86ced --- /dev/null +++ b/.venv/lib/python3.11/site-packages/starlette/websockets.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import enum +import json +import typing + +from starlette.requests import HTTPConnection +from starlette.responses import Response +from starlette.types import Message, Receive, Scope, Send + + +class WebSocketState(enum.Enum): + CONNECTING = 0 + CONNECTED = 1 + DISCONNECTED = 2 + RESPONSE = 3 + + +class WebSocketDisconnect(Exception): + def __init__(self, code: int = 1000, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocket(HTTPConnection): + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + super().__init__(scope) + assert scope["type"] == "websocket" + self._receive = receive + self._send = send + self.client_state = WebSocketState.CONNECTING + self.application_state = WebSocketState.CONNECTING + + async def receive(self) -> Message: + """ + Receive ASGI websocket messages, ensuring valid state transitions. + """ + if self.client_state == WebSocketState.CONNECTING: + message = await self._receive() + message_type = message["type"] + if message_type != "websocket.connect": + raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}') + self.client_state = WebSocketState.CONNECTED + return message + elif self.client_state == WebSocketState.CONNECTED: + message = await self._receive() + message_type = message["type"] + if message_type not in {"websocket.receive", "websocket.disconnect"}: + raise RuntimeError( + f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}' + ) + if message_type == "websocket.disconnect": + self.client_state = WebSocketState.DISCONNECTED + return message + else: + raise RuntimeError('Cannot call "receive" once a disconnect message has been received.') + + async def send(self, message: Message) -> None: + """ + Send ASGI websocket messages, ensuring valid state transitions. + """ + if self.application_state == WebSocketState.CONNECTING: + message_type = message["type"] + if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}: + raise RuntimeError( + 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", ' + f"but got {message_type!r}" + ) + if message_type == "websocket.close": + self.application_state = WebSocketState.DISCONNECTED + elif message_type == "websocket.http.response.start": + self.application_state = WebSocketState.RESPONSE + else: + self.application_state = WebSocketState.CONNECTED + await self._send(message) + elif self.application_state == WebSocketState.CONNECTED: + message_type = message["type"] + if message_type not in {"websocket.send", "websocket.close"}: + raise RuntimeError( + f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}' + ) + if message_type == "websocket.close": + self.application_state = WebSocketState.DISCONNECTED + try: + await self._send(message) + except OSError: + self.application_state = WebSocketState.DISCONNECTED + raise WebSocketDisconnect(code=1006) + elif self.application_state == WebSocketState.RESPONSE: + message_type = message["type"] + if message_type != "websocket.http.response.body": + raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}') + if not message.get("more_body", False): + self.application_state = WebSocketState.DISCONNECTED + await self._send(message) + else: + raise RuntimeError('Cannot call "send" once a close message has been sent.') + + async def accept( + self, + subprotocol: str | None = None, + headers: typing.Iterable[tuple[bytes, bytes]] | None = None, + ) -> None: + headers = headers or [] + + if self.client_state == WebSocketState.CONNECTING: # pragma: no branch + # If we haven't yet seen the 'connect' message, then wait for it first. + await self.receive() + await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}) + + def _raise_on_disconnect(self, message: Message) -> None: + if message["type"] == "websocket.disconnect": + raise WebSocketDisconnect(message["code"], message.get("reason")) + + async def receive_text(self) -> str: + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + message = await self.receive() + self._raise_on_disconnect(message) + return typing.cast(str, message["text"]) + + async def receive_bytes(self) -> bytes: + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + message = await self.receive() + self._raise_on_disconnect(message) + return typing.cast(bytes, message["bytes"]) + + async def receive_json(self, mode: str = "text") -> typing.Any: + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + message = await self.receive() + self._raise_on_disconnect(message) + + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) + + async def iter_text(self) -> typing.AsyncIterator[str]: + try: + while True: + yield await self.receive_text() + except WebSocketDisconnect: + pass + + async def iter_bytes(self) -> typing.AsyncIterator[bytes]: + try: + while True: + yield await self.receive_bytes() + except WebSocketDisconnect: + pass + + async def iter_json(self) -> typing.AsyncIterator[typing.Any]: + try: + while True: + yield await self.receive_json() + except WebSocketDisconnect: + pass + + async def send_text(self, data: str) -> None: + await self.send({"type": "websocket.send", "text": data}) + + async def send_bytes(self, data: bytes) -> None: + await self.send({"type": "websocket.send", "bytes": data}) + + async def send_json(self, data: typing.Any, mode: str = "text") -> None: + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') + text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) + if mode == "text": + await self.send({"type": "websocket.send", "text": text}) + else: + await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + await self.send({"type": "websocket.close", "code": code, "reason": reason or ""}) + + async def send_denial_response(self, response: Response) -> None: + if "websocket.http.response" in self.scope.get("extensions", {}): + await response(self.scope, self.receive, self.send) + else: + raise RuntimeError("The server doesn't support the Websocket Denial Response extension.") + + +class WebSocketClose: + def __init__(self, code: int = 1000, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.close", "code": self.code, "reason": self.reason}) diff --git a/.venv/lib/python3.11/site-packages/websockets/__main__.py b/.venv/lib/python3.11/site-packages/websockets/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..8647481d07cc02985ede265ba2918bd31aaa8ed2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/__main__.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import argparse +import os +import signal +import sys +import threading + + +try: + import readline # noqa: F401 +except ImportError: # Windows has no `readline` normally + pass + +from .sync.client import ClientConnection, connect +from .version import version as websockets_version + + +if sys.platform == "win32": + + def win_enable_vt100() -> None: + """ + Enable VT-100 for console output on Windows. + + See also https://github.com/python/cpython/issues/73245. + + """ + import ctypes + + STD_OUTPUT_HANDLE = ctypes.c_uint(-11) + INVALID_HANDLE_VALUE = ctypes.c_uint(-1) + ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004 + + handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) + if handle == INVALID_HANDLE_VALUE: + raise RuntimeError("unable to obtain stdout handle") + + cur_mode = ctypes.c_uint() + if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0: + raise RuntimeError("unable to query current console mode") + + # ctypes ints lack support for the required bit-OR operation. + # Temporarily convert to Py int, do the OR and convert back. + py_int_mode = int.from_bytes(cur_mode, sys.byteorder) + new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING) + + if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0: + raise RuntimeError("unable to set console mode") + + +def print_during_input(string: str) -> None: + sys.stdout.write( + # Save cursor position + "\N{ESC}7" + # Add a new line + "\N{LINE FEED}" + # Move cursor up + "\N{ESC}[A" + # Insert blank line, scroll last line down + "\N{ESC}[L" + # Print string in the inserted blank line + f"{string}\N{LINE FEED}" + # Restore cursor position + "\N{ESC}8" + # Move cursor down + "\N{ESC}[B" + ) + sys.stdout.flush() + + +def print_over_input(string: str) -> None: + sys.stdout.write( + # Move cursor to beginning of line + "\N{CARRIAGE RETURN}" + # Delete current line + "\N{ESC}[K" + # Print string + f"{string}\N{LINE FEED}" + ) + sys.stdout.flush() + + +def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None: + for message in websocket: + if isinstance(message, str): + print_during_input("< " + message) + else: + print_during_input("< (binary) " + message.hex()) + if not stop.is_set(): + # When the server closes the connection, raise KeyboardInterrupt + # in the main thread to exit the program. + if sys.platform == "win32": + ctrl_c = signal.CTRL_C_EVENT + else: + ctrl_c = signal.SIGINT + os.kill(os.getpid(), ctrl_c) + + +def main() -> None: + # Parse command line arguments. + parser = argparse.ArgumentParser( + prog="python -m websockets", + description="Interactive WebSocket client.", + add_help=False, + ) + group = parser.add_mutually_exclusive_group() + group.add_argument("--version", action="store_true") + group.add_argument("uri", metavar="", nargs="?") + args = parser.parse_args() + + if args.version: + print(f"websockets {websockets_version}") + return + + if args.uri is None: + parser.error("the following arguments are required: ") + + # If we're on Windows, enable VT100 terminal support. + if sys.platform == "win32": + try: + win_enable_vt100() + except RuntimeError as exc: + sys.stderr.write( + f"Unable to set terminal to VT100 mode. This is only " + f"supported since Win10 anniversary update. Expect " + f"weird symbols on the terminal.\nError: {exc}\n" + ) + sys.stderr.flush() + + try: + websocket = connect(args.uri) + except Exception as exc: + print(f"Failed to connect to {args.uri}: {exc}.") + sys.exit(1) + else: + print(f"Connected to {args.uri}.") + + stop = threading.Event() + + # Start the thread that reads messages from the connection. + thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop)) + thread.start() + + # Read from stdin in the main thread in order to receive signals. + try: + while True: + # Since there's no size limit, put_nowait is identical to put. + message = input("> ") + websocket.send(message) + except (KeyboardInterrupt, EOFError): # ^C, ^D + stop.set() + websocket.close() + print_over_input("Connection closed.") + + thread.join() + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/websockets/auth.py b/.venv/lib/python3.11/site-packages/websockets/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..15b70a3727b2eb3202fc87173ad2fc8b742cf72c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/auth.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import warnings + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.auth import * + from .legacy.auth import __all__ # noqa: F401 + + +warnings.warn( # deprecated in 14.0 - 2024-11-09 + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/.venv/lib/python3.11/site-packages/websockets/client.py b/.venv/lib/python3.11/site-packages/websockets/client.py new file mode 100644 index 0000000000000000000000000000000000000000..37e2a8b3a5a7d3fe46d64a622009729e97ba9521 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/client.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +import os +import random +import warnings +from collections.abc import Generator, Sequence +from typing import Any + +from .datastructures import Headers, MultipleValuesError +from .exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidMessage, + InvalidStatus, + InvalidUpgrade, + NegotiationError, +) +from .extensions import ClientExtensionFactory, Extension +from .headers import ( + build_authorization_basic, + build_extension, + build_host, + build_subprotocol, + parse_connection, + parse_extension, + parse_subprotocol, + parse_upgrade, +) +from .http11 import Request, Response +from .imports import lazy_import +from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State +from .typing import ( + ConnectionOption, + ExtensionHeader, + LoggerLike, + Origin, + Subprotocol, + UpgradeProtocol, +) +from .uri import WebSocketURI +from .utils import accept_key, generate_key + + +__all__ = ["ClientProtocol"] + + +class ClientProtocol(Protocol): + """ + Sans-I/O implementation of a WebSocket client connection. + + Args: + wsuri: URI of the WebSocket server, parsed + with :func:`~websockets.uri.parse_uri`. + origin: Value of the ``Origin`` header. This is useful when connecting + to a server that validates the ``Origin`` header to defend against + Cross-Site WebSocket Hijacking attacks. + extensions: List of supported extensions, in order in which they + should be tried. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; + :obj:`None` disables the limit. + logger: Logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + wsuri: WebSocketURI, + *, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + state: State = CONNECTING, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, + ) -> None: + super().__init__( + side=CLIENT, + state=state, + max_size=max_size, + logger=logger, + ) + self.wsuri = wsuri + self.origin = origin + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.key = generate_key() + + def connect(self) -> Request: + """ + Create a handshake request to open a connection. + + You must send the handshake request with :meth:`send_request`. + + You can modify it before sending it, for example to add HTTP headers. + + Returns: + WebSocket handshake request event to send to the server. + + """ + headers = Headers() + + headers["Host"] = build_host( + self.wsuri.host, self.wsuri.port, self.wsuri.secure + ) + + if self.wsuri.user_info: + headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info) + + if self.origin is not None: + headers["Origin"] = self.origin + + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Key"] = self.key + headers["Sec-WebSocket-Version"] = "13" + + if self.available_extensions is not None: + extensions_header = build_extension( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in self.available_extensions + ] + ) + headers["Sec-WebSocket-Extensions"] = extensions_header + + if self.available_subprotocols is not None: + protocol_header = build_subprotocol(self.available_subprotocols) + headers["Sec-WebSocket-Protocol"] = protocol_header + + return Request(self.wsuri.resource_name, headers) + + def process_response(self, response: Response) -> None: + """ + Check a handshake response. + + Args: + request: WebSocket handshake response received from the server. + + Raises: + InvalidHandshake: If the handshake response is invalid. + + """ + + if response.status_code != 101: + raise InvalidStatus(response) + + headers = response.headers + + connection: list[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade( + "Connection", ", ".join(connection) if connection else None + ) + + upgrade: list[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. It's supposed to be 'WebSocket'. + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) + + try: + s_w_accept = headers["Sec-WebSocket-Accept"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Accept") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None + + if s_w_accept != accept_key(self.key): + raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) + + self.extensions = self.process_extensions(headers) + + self.subprotocol = self.process_subprotocol(headers) + + def process_extensions(self, headers: Headers) -> list[Extension]: + """ + Handle the Sec-WebSocket-Extensions HTTP response header. + + Check that each extension is supported, as well as its parameters. + + :rfc:`6455` leaves the rules up to the specification of each + extension. + + To provide this level of flexibility, for each extension accepted by + the server, we check for a match with each extension available in the + client configuration. If no match is found, an exception is raised. + + If several variants of the same extension are accepted by the server, + it may be configured several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + Args: + headers: WebSocket handshake response headers. + + Returns: + List of accepted extensions. + + Raises: + InvalidHandshake: To abort the handshake. + + """ + accepted_extensions: list[Extension] = [] + + extensions = headers.get_all("Sec-WebSocket-Extensions") + + if extensions: + if self.available_extensions is None: + raise NegotiationError("no extensions supported") + + parsed_extensions: list[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in extensions], [] + ) + + for name, response_params in parsed_extensions: + for extension_factory in self.available_extensions: + # Skip non-matching extensions based on their name. + if extension_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + extension = extension_factory.process_response_params( + response_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the server sent. Fail the connection. + else: + raise NegotiationError( + f"Unsupported extension: " + f"name = {name}, params = {response_params}" + ) + + return accepted_extensions + + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: + """ + Handle the Sec-WebSocket-Protocol HTTP response header. + + If provided, check that it contains exactly one supported subprotocol. + + Args: + headers: WebSocket handshake response headers. + + Returns: + Subprotocol, if one was selected. + + """ + subprotocol: Subprotocol | None = None + + subprotocols = headers.get_all("Sec-WebSocket-Protocol") + + if subprotocols: + if self.available_subprotocols is None: + raise NegotiationError("no subprotocols supported") + + parsed_subprotocols: Sequence[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in subprotocols], [] + ) + + if len(parsed_subprotocols) > 1: + raise InvalidHeader( + "Sec-WebSocket-Protocol", + f"multiple values: {', '.join(parsed_subprotocols)}", + ) + + subprotocol = parsed_subprotocols[0] + + if subprotocol not in self.available_subprotocols: + raise NegotiationError(f"unsupported subprotocol: {subprotocol}") + + return subprotocol + + def send_request(self, request: Request) -> None: + """ + Send a handshake request to the server. + + Args: + request: WebSocket handshake request event. + + """ + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + + self.writes.append(request.serialize()) + + def parse(self) -> Generator[None]: + if self.state is CONNECTING: + try: + response = yield from Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + ) + except Exception as exc: + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP response" + ) + self.handshake_exc.__cause__ = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield + + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("< HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + if response.body: + self.logger.debug("< [body] (%d bytes)", len(response.body)) + + try: + self.process_response(response) + except InvalidHandshake as exc: + response._exception = exc + self.events.append(response) + self.handshake_exc = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield + + assert self.state is CONNECTING + self.state = OPEN + self.events.append(response) + + yield from super().parse() + + +class ClientConnection(ClientProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( # deprecated in 11.0 - 2023-04-02 + "ClientConnection was renamed to ClientProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) + + +BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) +BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) +BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) +BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) + + +def backoff( + initial_delay: float = BACKOFF_INITIAL_DELAY, + min_delay: float = BACKOFF_MIN_DELAY, + max_delay: float = BACKOFF_MAX_DELAY, + factor: float = BACKOFF_FACTOR, +) -> Generator[float]: + """ + Generate a series of backoff delays between reconnection attempts. + + Yields: + How many seconds to wait before retrying to connect. + + """ + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. + yield random.random() * initial_delay + delay = min_delay + while delay < max_delay: + yield delay + delay *= factor + while True: + yield max_delay + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + }, +) diff --git a/.venv/lib/python3.11/site-packages/websockets/connection.py b/.venv/lib/python3.11/site-packages/websockets/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..5e78e34479224d0332b165badd67a8933e0c73db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/connection.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import warnings + +from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 + + +warnings.warn( # deprecated in 11.0 - 2023-04-02 + "websockets.connection was renamed to websockets.protocol " + "and Connection was renamed to Protocol", + DeprecationWarning, +) diff --git a/.venv/lib/python3.11/site-packages/websockets/datastructures.py b/.venv/lib/python3.11/site-packages/websockets/datastructures.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dcbe9a8a9a0de44b457e3b813814d11f67445 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/datastructures.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Mapping, MutableMapping +from typing import Any, Protocol, Union + + +__all__ = [ + "Headers", + "HeadersLike", + "MultipleValuesError", +] + + +class MultipleValuesError(LookupError): + """ + Exception raised when :class:`Headers` has multiple values for a key. + + """ + + def __str__(self) -> str: + # Implement the same logic as KeyError_str in Objects/exceptions.c. + if len(self.args) == 1: + return repr(self.args[0]) + return super().__str__() + + +class Headers(MutableMapping[str, str]): + """ + Efficient data structure for manipulating HTTP headers. + + A :class:`list` of ``(name, values)`` is inefficient for lookups. + + A :class:`dict` doesn't suffice because header names are case-insensitive + and multiple occurrences of headers with the same name are possible. + + :class:`Headers` stores HTTP headers in a hybrid data structure to provide + efficient insertions and lookups while preserving the original data. + + In order to account for multiple values with minimal hassle, + :class:`Headers` follows this logic: + + - When getting a header with ``headers[name]``: + - if there's no value, :exc:`KeyError` is raised; + - if there's exactly one value, it's returned; + - if there's more than one value, :exc:`MultipleValuesError` is raised. + + - When setting a header with ``headers[name] = value``, the value is + appended to the list of values for that header. + + - When deleting a header with ``del headers[name]``, all values for that + header are removed (this is slow). + + Other methods for manipulating headers are consistent with this logic. + + As long as no header occurs multiple times, :class:`Headers` behaves like + :class:`dict`, except keys are lower-cased to provide case-insensitivity. + + Two methods support manipulating multiple values explicitly: + + - :meth:`get_all` returns a list of all values for a header; + - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs. + + """ + + __slots__ = ["_dict", "_list"] + + # Like dict, Headers accepts an optional "mapping or iterable" argument. + def __init__(self, *args: HeadersLike, **kwargs: str) -> None: + self._dict: dict[str, list[str]] = {} + self._list: list[tuple[str, str]] = [] + self.update(*args, **kwargs) + + def __str__(self) -> str: + return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._list!r})" + + def copy(self) -> Headers: + copy = self.__class__() + copy._dict = self._dict.copy() + copy._list = self._list.copy() + return copy + + def serialize(self) -> bytes: + # Since headers only contain ASCII characters, we can keep this simple. + return str(self).encode() + + # Collection methods + + def __contains__(self, key: object) -> bool: + return isinstance(key, str) and key.lower() in self._dict + + def __iter__(self) -> Iterator[str]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + # MutableMapping methods + + def __getitem__(self, key: str) -> str: + value = self._dict[key.lower()] + if len(value) == 1: + return value[0] + else: + raise MultipleValuesError(key) + + def __setitem__(self, key: str, value: str) -> None: + self._dict.setdefault(key.lower(), []).append(value) + self._list.append((key, value)) + + def __delitem__(self, key: str) -> None: + key_lower = key.lower() + self._dict.__delitem__(key_lower) + # This is inefficient. Fortunately deleting HTTP headers is uncommon. + self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Headers): + return NotImplemented + return self._dict == other._dict + + def clear(self) -> None: + """ + Remove all headers. + + """ + self._dict = {} + self._list = [] + + def update(self, *args: HeadersLike, **kwargs: str) -> None: + """ + Update from a :class:`Headers` instance and/or keyword arguments. + + """ + args = tuple( + arg.raw_items() if isinstance(arg, Headers) else arg for arg in args + ) + super().update(*args, **kwargs) + + # Methods for handling multiple values + + def get_all(self, key: str) -> list[str]: + """ + Return the (possibly empty) list of all values for a header. + + Args: + key: Header name. + + """ + return self._dict.get(key.lower(), []) + + def raw_items(self) -> Iterator[tuple[str, str]]: + """ + Return an iterator of all values as ``(name, value)`` pairs. + + """ + return iter(self._list) + + +# copy of _typeshed.SupportsKeysAndGetItem. +class SupportsKeysAndGetItem(Protocol): # pragma: no cover + """ + Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods. + + """ + + def keys(self) -> Iterable[str]: ... + + def __getitem__(self, key: str) -> str: ... + + +# Change to Headers | Mapping[str, str] | ... when dropping Python < 3.10. +HeadersLike = Union[ + Headers, + Mapping[str, str], + Iterable[tuple[str, str]], + SupportsKeysAndGetItem, +] +""" +Types accepted where :class:`Headers` is expected. + +In addition to :class:`Headers` itself, this includes dict-like types where both +keys and values are :class:`str`. + +""" diff --git a/.venv/lib/python3.11/site-packages/websockets/exceptions.py b/.venv/lib/python3.11/site-packages/websockets/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..73b24debfcab3e8a1c78125ad72a5610980317b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/exceptions.py @@ -0,0 +1,424 @@ +""" +:mod:`websockets.exceptions` defines the following hierarchy of exceptions. + +* :exc:`WebSocketException` + * :exc:`ConnectionClosed` + * :exc:`ConnectionClosedOK` + * :exc:`ConnectionClosedError` + * :exc:`InvalidURI` + * :exc:`InvalidHandshake` + * :exc:`SecurityError` + * :exc:`InvalidMessage` + * :exc:`InvalidStatus` + * :exc:`InvalidStatusCode` (legacy) + * :exc:`InvalidHeader` + * :exc:`InvalidHeaderFormat` + * :exc:`InvalidHeaderValue` + * :exc:`InvalidOrigin` + * :exc:`InvalidUpgrade` + * :exc:`NegotiationError` + * :exc:`DuplicateParameter` + * :exc:`InvalidParameterName` + * :exc:`InvalidParameterValue` + * :exc:`AbortHandshake` (legacy) + * :exc:`RedirectHandshake` (legacy) + * :exc:`ProtocolError` (Sans-I/O) + * :exc:`PayloadTooBig` (Sans-I/O) + * :exc:`InvalidState` (Sans-I/O) + * :exc:`ConcurrencyError` + +""" + +from __future__ import annotations + +import warnings + +from .imports import lazy_import + + +__all__ = [ + "WebSocketException", + "ConnectionClosed", + "ConnectionClosedOK", + "ConnectionClosedError", + "InvalidURI", + "InvalidHandshake", + "SecurityError", + "InvalidStatus", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidMessage", + "InvalidOrigin", + "InvalidUpgrade", + "NegotiationError", + "DuplicateParameter", + "InvalidParameterName", + "InvalidParameterValue", + "ProtocolError", + "PayloadTooBig", + "InvalidState", + "ConcurrencyError", +] + + +class WebSocketException(Exception): + """ + Base class for all exceptions defined by websockets. + + """ + + +class ConnectionClosed(WebSocketException): + """ + Raised when trying to interact with a closed connection. + + Attributes: + rcvd: If a close frame was received, its code and reason are available + in ``rcvd.code`` and ``rcvd.reason``. + sent: If a close frame was sent, its code and reason are available + in ``sent.code`` and ``sent.reason``. + rcvd_then_sent: If close frames were received and sent, this attribute + tells in which order this happened, from the perspective of this + side of the connection. + + """ + + def __init__( + self, + rcvd: frames.Close | None, + sent: frames.Close | None, + rcvd_then_sent: bool | None = None, + ) -> None: + self.rcvd = rcvd + self.sent = sent + self.rcvd_then_sent = rcvd_then_sent + assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None) + + def __str__(self) -> str: + if self.rcvd is None: + if self.sent is None: + return "no close frame received or sent" + else: + return f"sent {self.sent}; no close frame received" + else: + if self.sent is None: + return f"received {self.rcvd}; no close frame sent" + else: + if self.rcvd_then_sent: + return f"received {self.rcvd}; then sent {self.sent}" + else: + return f"sent {self.sent}; then received {self.rcvd}" + + # code and reason attributes are provided for backwards-compatibility + + @property + def code(self) -> int: + warnings.warn( # deprecated in 13.1 - 2024-09-21 + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code", + DeprecationWarning, + ) + if self.rcvd is None: + return frames.CloseCode.ABNORMAL_CLOSURE + return self.rcvd.code + + @property + def reason(self) -> str: + warnings.warn( # deprecated in 13.1 - 2024-09-21 + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason", + DeprecationWarning, + ) + if self.rcvd is None: + return "" + return self.rcvd.reason + + +class ConnectionClosedOK(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated properly. + + A close code with code 1000 (OK) or 1001 (going away) or without a code was + received and sent. + + """ + + +class ConnectionClosedError(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated with an error. + + A close frame with a code other than 1000 (OK) or 1001 (going away) was + received or sent, or the closing handshake didn't complete properly. + + """ + + +class InvalidURI(WebSocketException): + """ + Raised when connecting to a URI that isn't a valid WebSocket URI. + + """ + + def __init__(self, uri: str, msg: str) -> None: + self.uri = uri + self.msg = msg + + def __str__(self) -> str: + return f"{self.uri} isn't a valid URI: {self.msg}" + + +class InvalidHandshake(WebSocketException): + """ + Base class for exceptions raised when the opening handshake fails. + + """ + + +class SecurityError(InvalidHandshake): + """ + Raised when a handshake request or response breaks a security rule. + + Security limits can be configured with :doc:`environment variables + <../reference/variables>`. + + """ + + +class InvalidMessage(InvalidHandshake): + """ + Raised when a handshake request or response is malformed. + + """ + + +class InvalidStatus(InvalidHandshake): + """ + Raised when a handshake response rejects the WebSocket upgrade. + + """ + + def __init__(self, response: http11.Response) -> None: + self.response = response + + def __str__(self) -> str: + return ( + f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" + ) + + +class InvalidHeader(InvalidHandshake): + """ + Raised when an HTTP header doesn't have a valid format or value. + + """ + + def __init__(self, name: str, value: str | None = None) -> None: + self.name = name + self.value = value + + def __str__(self) -> str: + if self.value is None: + return f"missing {self.name} header" + elif self.value == "": + return f"empty {self.name} header" + else: + return f"invalid {self.name} header: {self.value}" + + +class InvalidHeaderFormat(InvalidHeader): + """ + Raised when an HTTP header cannot be parsed. + + The format of the header doesn't match the grammar for that header. + + """ + + def __init__(self, name: str, error: str, header: str, pos: int) -> None: + super().__init__(name, f"{error} at {pos} in {header}") + + +class InvalidHeaderValue(InvalidHeader): + """ + Raised when an HTTP header has a wrong value. + + The format of the header is correct but the value isn't acceptable. + + """ + + +class InvalidOrigin(InvalidHeader): + """ + Raised when the Origin header in a request isn't allowed. + + """ + + def __init__(self, origin: str | None) -> None: + super().__init__("Origin", origin) + + +class InvalidUpgrade(InvalidHeader): + """ + Raised when the Upgrade or Connection header isn't correct. + + """ + + +class NegotiationError(InvalidHandshake): + """ + Raised when negotiating an extension or a subprotocol fails. + + """ + + +class DuplicateParameter(NegotiationError): + """ + Raised when a parameter name is repeated in an extension header. + + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __str__(self) -> str: + return f"duplicate parameter: {self.name}" + + +class InvalidParameterName(NegotiationError): + """ + Raised when a parameter name in an extension header is invalid. + + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __str__(self) -> str: + return f"invalid parameter name: {self.name}" + + +class InvalidParameterValue(NegotiationError): + """ + Raised when a parameter value in an extension header is invalid. + + """ + + def __init__(self, name: str, value: str | None) -> None: + self.name = name + self.value = value + + def __str__(self) -> str: + if self.value is None: + return f"missing value for parameter {self.name}" + elif self.value == "": + return f"empty value for parameter {self.name}" + else: + return f"invalid value for parameter {self.name}: {self.value}" + + +class ProtocolError(WebSocketException): + """ + Raised when receiving or sending a frame that breaks the protocol. + + The Sans-I/O implementation raises this exception when: + + * receiving or sending a frame that contains invalid data; + * receiving or sending an invalid sequence of frames. + + """ + + +class PayloadTooBig(WebSocketException): + """ + Raised when parsing a frame with a payload that exceeds the maximum size. + + The Sans-I/O layer uses this exception internally. It doesn't bubble up to + the I/O layer. + + The :meth:`~websockets.extensions.Extension.decode` method of extensions + must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit. + + """ + + def __init__( + self, + size_or_message: int | None | str, + max_size: int | None = None, + cur_size: int | None = None, + ) -> None: + if isinstance(size_or_message, str): + assert max_size is None + assert cur_size is None + warnings.warn( # deprecated in 14.0 - 2024-11-09 + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + DeprecationWarning, + ) + self.message: str | None = size_or_message + else: + self.message = None + self.size: int | None = size_or_message + assert max_size is not None + self.max_size: int = max_size + self.cur_size: int | None = None + self.set_current_size(cur_size) + + def __str__(self) -> str: + if self.message is not None: + return self.message + else: + message = "frame " + if self.size is not None: + message += f"with {self.size} bytes " + if self.cur_size is not None: + message += f"after reading {self.cur_size} bytes " + message += f"exceeds limit of {self.max_size} bytes" + return message + + def set_current_size(self, cur_size: int | None) -> None: + assert self.cur_size is None + if cur_size is not None: + self.max_size += cur_size + self.cur_size = cur_size + + +class InvalidState(WebSocketException, AssertionError): + """ + Raised when sending a frame is forbidden in the current state. + + Specifically, the Sans-I/O layer raises this exception when: + + * sending a data frame to a connection in a state other + :attr:`~websockets.protocol.State.OPEN`; + * sending a control frame to a connection in a state other than + :attr:`~websockets.protocol.State.OPEN` or + :attr:`~websockets.protocol.State.CLOSING`. + + """ + + +class ConcurrencyError(WebSocketException, RuntimeError): + """ + Raised when receiving or sending messages concurrently. + + WebSocket is a connection-oriented protocol. Reads must be serialized; so + must be writes. However, reading and writing concurrently is possible. + + """ + + +# At the bottom to break import cycles created by type annotations. +from . import frames, http11 # noqa: E402 + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "AbortHandshake": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + }, +) diff --git a/.venv/lib/python3.11/site-packages/websockets/frames.py b/.venv/lib/python3.11/site-packages/websockets/frames.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0869d013389a8c1c0d4b48db85c8a1619b44ef --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/frames.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import dataclasses +import enum +import io +import os +import secrets +import struct +from collections.abc import Generator, Sequence +from typing import Callable, Union + +from .exceptions import PayloadTooBig, ProtocolError + + +try: + from .speedups import apply_mask +except ImportError: + from .utils import apply_mask + + +__all__ = [ + "Opcode", + "OP_CONT", + "OP_TEXT", + "OP_BINARY", + "OP_CLOSE", + "OP_PING", + "OP_PONG", + "DATA_OPCODES", + "CTRL_OPCODES", + "CloseCode", + "Frame", + "Close", +] + + +class Opcode(enum.IntEnum): + """Opcode values for WebSocket frames.""" + + CONT, TEXT, BINARY = 0x00, 0x01, 0x02 + CLOSE, PING, PONG = 0x08, 0x09, 0x0A + + +OP_CONT = Opcode.CONT +OP_TEXT = Opcode.TEXT +OP_BINARY = Opcode.BINARY +OP_CLOSE = Opcode.CLOSE +OP_PING = Opcode.PING +OP_PONG = Opcode.PONG + +DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY +CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG + + +class CloseCode(enum.IntEnum): + """Close code values for WebSocket close frames.""" + + NORMAL_CLOSURE = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 + # 1004 is reserved + NO_STATUS_RCVD = 1005 + ABNORMAL_CLOSURE = 1006 + INVALID_DATA = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXTENSION = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + BAD_GATEWAY = 1014 + TLS_HANDSHAKE = 1015 + + +# See https://www.iana.org/assignments/websocket/websocket.xhtml +CLOSE_CODE_EXPLANATIONS: dict[int, str] = { + CloseCode.NORMAL_CLOSURE: "OK", + CloseCode.GOING_AWAY: "going away", + CloseCode.PROTOCOL_ERROR: "protocol error", + CloseCode.UNSUPPORTED_DATA: "unsupported data", + CloseCode.NO_STATUS_RCVD: "no status received [internal]", + CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]", + CloseCode.INVALID_DATA: "invalid frame payload data", + CloseCode.POLICY_VIOLATION: "policy violation", + CloseCode.MESSAGE_TOO_BIG: "message too big", + CloseCode.MANDATORY_EXTENSION: "mandatory extension", + CloseCode.INTERNAL_ERROR: "internal error", + CloseCode.SERVICE_RESTART: "service restart", + CloseCode.TRY_AGAIN_LATER: "try again later", + CloseCode.BAD_GATEWAY: "bad gateway", + CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]", +} + + +# Close code that are allowed in a close frame. +# Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. +EXTERNAL_CLOSE_CODES = { + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.PROTOCOL_ERROR, + CloseCode.UNSUPPORTED_DATA, + CloseCode.INVALID_DATA, + CloseCode.POLICY_VIOLATION, + CloseCode.MESSAGE_TOO_BIG, + CloseCode.MANDATORY_EXTENSION, + CloseCode.INTERNAL_ERROR, + CloseCode.SERVICE_RESTART, + CloseCode.TRY_AGAIN_LATER, + CloseCode.BAD_GATEWAY, +} + + +OK_CLOSE_CODES = { + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.NO_STATUS_RCVD, +} + + +BytesLike = bytes, bytearray, memoryview + + +@dataclasses.dataclass +class Frame: + """ + WebSocket frame. + + Attributes: + opcode: Opcode. + data: Payload data. + fin: FIN bit. + rsv1: RSV1 bit. + rsv2: RSV2 bit. + rsv3: RSV3 bit. + + Only these fields are needed. The MASK bit, payload length and masking-key + are handled on the fly when parsing and serializing frames. + + """ + + opcode: Opcode + data: Union[bytes, bytearray, memoryview] + fin: bool = True + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False + + # Configure if you want to see more in logs. Should be a multiple of 3. + MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75")) + + def __str__(self) -> str: + """ + Return a human-readable representation of a frame. + + """ + coding = None + length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}" + non_final = "" if self.fin else "continued" + + if self.opcode is OP_TEXT: + # Decoding only the beginning and the end is needlessly hard. + # Decode the entire payload then elide later if necessary. + data = repr(bytes(self.data).decode()) + elif self.opcode is OP_BINARY: + # We'll show at most the first 16 bytes and the last 8 bytes. + # Encode just what we need, plus two dummy bytes to elide later. + binary = self.data + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) + data = " ".join(f"{byte:02x}" for byte in binary) + elif self.opcode is OP_CLOSE: + data = str(Close.parse(self.data)) + elif self.data: + # We don't know if a Continuation frame contains text or binary. + # Ping and Pong frames could contain UTF-8. + # Attempt to decode as UTF-8 and display it as text; fallback to + # binary. If self.data is a memoryview, it has no decode() method, + # which raises AttributeError. + try: + data = repr(bytes(self.data).decode()) + coding = "text" + except (UnicodeDecodeError, AttributeError): + binary = self.data + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) + data = " ".join(f"{byte:02x}" for byte in binary) + coding = "binary" + else: + data = "''" + + if len(data) > self.MAX_LOG_SIZE: + cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24 + data = data[: 2 * cut] + "..." + data[-cut:] + + metadata = ", ".join(filter(None, [coding, length, non_final])) + + return f"{self.opcode.name} {data} [{metadata}]" + + @classmethod + def parse( + cls, + read_exact: Callable[[int], Generator[None, None, bytes]], + *, + mask: bool, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, + ) -> Generator[None, None, Frame]: + """ + Parse a WebSocket frame. + + This is a generator-based coroutine. + + Args: + read_exact: Generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + mask: Whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. + + Raises: + EOFError: If the connection is closed without a full WebSocket frame. + PayloadTooBig: If the frame's payload size exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. + + """ + # Read the header. + data = yield from read_exact(2) + head1, head2 = struct.unpack("!BB", data) + + # While not Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False + + try: + opcode = Opcode(head1 & 0b00001111) + except ValueError as exc: + raise ProtocolError("invalid opcode") from exc + + if (True if head2 & 0b10000000 else False) != mask: + raise ProtocolError("incorrect masking") + + length = head2 & 0b01111111 + if length == 126: + data = yield from read_exact(2) + (length,) = struct.unpack("!H", data) + elif length == 127: + data = yield from read_exact(8) + (length,) = struct.unpack("!Q", data) + if max_size is not None and length > max_size: + raise PayloadTooBig(length, max_size) + if mask: + mask_bytes = yield from read_exact(4) + + # Read the data. + data = yield from read_exact(length) + if mask: + data = apply_mask(data, mask_bytes) + + frame = cls(opcode, data, fin, rsv1, rsv2, rsv3) + + if extensions is None: + extensions = [] + for extension in reversed(extensions): + frame = extension.decode(frame, max_size=max_size) + + frame.check() + + return frame + + def serialize( + self, + *, + mask: bool, + extensions: Sequence[extensions.Extension] | None = None, + ) -> bytes: + """ + Serialize a WebSocket frame. + + Args: + mask: Whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: List of extensions, applied in order. + + Raises: + ProtocolError: If the frame contains incorrect values. + + """ + self.check() + + if extensions is None: + extensions = [] + for extension in extensions: + self = extension.encode(self) + + output = io.BytesIO() + + # Prepare the header. + head1 = ( + (0b10000000 if self.fin else 0) + | (0b01000000 if self.rsv1 else 0) + | (0b00100000 if self.rsv2 else 0) + | (0b00010000 if self.rsv3 else 0) + | self.opcode + ) + + head2 = 0b10000000 if mask else 0 + + length = len(self.data) + if length < 126: + output.write(struct.pack("!BB", head1, head2 | length)) + elif length < 65536: + output.write(struct.pack("!BBH", head1, head2 | 126, length)) + else: + output.write(struct.pack("!BBQ", head1, head2 | 127, length)) + + if mask: + mask_bytes = secrets.token_bytes(4) + output.write(mask_bytes) + + # Prepare the data. + if mask: + data = apply_mask(self.data, mask_bytes) + else: + data = self.data + output.write(data) + + return output.getvalue() + + def check(self) -> None: + """ + Check that reserved bits and opcode have acceptable values. + + Raises: + ProtocolError: If a reserved bit or the opcode is invalid. + + """ + if self.rsv1 or self.rsv2 or self.rsv3: + raise ProtocolError("reserved bits must be 0") + + if self.opcode in CTRL_OPCODES: + if len(self.data) > 125: + raise ProtocolError("control frame too long") + if not self.fin: + raise ProtocolError("fragmented control frame") + + +@dataclasses.dataclass +class Close: + """ + Code and reason for WebSocket close frames. + + Attributes: + code: Close code. + reason: Close reason. + + """ + + code: int + reason: str + + def __str__(self) -> str: + """ + Return a human-readable representation of a close code and reason. + + """ + if 3000 <= self.code < 4000: + explanation = "registered" + elif 4000 <= self.code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown") + result = f"{self.code} ({explanation})" + + if self.reason: + result = f"{result} {self.reason}" + + return result + + @classmethod + def parse(cls, data: bytes) -> Close: + """ + Parse the payload of a close frame. + + Args: + data: Payload of the close frame. + + Raises: + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. + + """ + if len(data) >= 2: + (code,) = struct.unpack("!H", data[:2]) + reason = data[2:].decode() + close = cls(code, reason) + close.check() + return close + elif len(data) == 0: + return cls(CloseCode.NO_STATUS_RCVD, "") + else: + raise ProtocolError("close frame too short") + + def serialize(self) -> bytes: + """ + Serialize the payload of a close frame. + + """ + self.check() + return struct.pack("!H", self.code) + self.reason.encode() + + def check(self) -> None: + """ + Check that the close code has a valid value for a close frame. + + Raises: + ProtocolError: If the close code is invalid. + + """ + if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): + raise ProtocolError("invalid status code") + + +# At the bottom to break import cycles created by type annotations. +from . import extensions # noqa: E402 diff --git a/.venv/lib/python3.11/site-packages/websockets/headers.py b/.venv/lib/python3.11/site-packages/websockets/headers.py new file mode 100644 index 0000000000000000000000000000000000000000..e05948a1f99676d07ee8c39b0c803c24c753652f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/headers.py @@ -0,0 +1,580 @@ +from __future__ import annotations + +import base64 +import binascii +import ipaddress +import re +from collections.abc import Sequence +from typing import Callable, TypeVar, cast + +from .exceptions import InvalidHeaderFormat, InvalidHeaderValue +from .typing import ( + ConnectionOption, + ExtensionHeader, + ExtensionName, + ExtensionParameter, + Subprotocol, + UpgradeProtocol, +) + + +__all__ = [ + "build_host", + "parse_connection", + "parse_upgrade", + "parse_extension", + "build_extension", + "parse_subprotocol", + "build_subprotocol", + "validate_subprotocols", + "build_www_authenticate_basic", + "parse_authorization_basic", + "build_authorization_basic", +] + + +T = TypeVar("T") + + +def build_host(host: str, port: int, secure: bool) -> str: + """ + Build a ``Host`` header. + + """ + # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 + # IPv6 addresses must be enclosed in brackets. + try: + address = ipaddress.ip_address(host) + except ValueError: + # host is a hostname + pass + else: + # host is an IP address + if address.version == 6: + host = f"[{host}]" + + if port != (443 if secure else 80): + host = f"{host}:{port}" + + return host + + +# To avoid a dependency on a parsing library, we implement manually the ABNF +# described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and +# https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. + + +def peek_ahead(header: str, pos: int) -> str | None: + """ + Return the next character from ``header`` at the given position. + + Return :obj:`None` at the end of ``header``. + + We never need to peek more than one character ahead. + + """ + return None if pos == len(header) else header[pos] + + +_OWS_re = re.compile(r"[\t ]*") + + +def parse_OWS(header: str, pos: int) -> int: + """ + Parse optional whitespace from ``header`` at the given position. + + Return the new position. + + The whitespace itself isn't returned because it isn't significant. + + """ + # There's always a match, possibly empty, whose content doesn't matter. + match = _OWS_re.match(header, pos) + assert match is not None + return match.end() + + +_token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + + +def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: + """ + Parse a token from ``header`` at the given position. + + Return the token value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _token_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected token", header, pos) + return match.group(), match.end() + + +_quoted_string_re = re.compile( + r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"' +) + + +_unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") + + +def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]: + """ + Parse a quoted string from ``header`` at the given position. + + Return the unquoted value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _quoted_string_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) + return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() + + +_quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*") + + +_quote_re = re.compile(r"([\x22\x5c])") + + +def build_quoted_string(value: str) -> str: + """ + Format ``value`` as a quoted string. + + This is the reverse of :func:`parse_quoted_string`. + + """ + match = _quotable_re.fullmatch(value) + if match is None: + raise ValueError("invalid characters for quoted-string encoding") + return '"' + _quote_re.sub(r"\\\1", value) + '"' + + +def parse_list( + parse_item: Callable[[str, int, str], tuple[T, int]], + header: str, + pos: int, + header_name: str, +) -> list[T]: + """ + Parse a comma-separated list from ``header`` at the given position. + + This is appropriate for parsing values with the following grammar: + + 1#item + + ``parse_item`` parses one item. + + ``header`` is assumed not to start or end with whitespace. + + (This function is designed for parsing an entire header value and + :func:`~websockets.http.read_headers` strips whitespace from values.) + + Return a list of items. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient + # MUST parse and ignore a reasonable number of empty list elements"; + # hence while loops that remove extra delimiters. + + # Remove extra delimiters before the first item. + while peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) + + items = [] + while True: + # Loop invariant: a item starts at pos in header. + item, pos = parse_item(header, pos, header_name) + items.append(item) + pos = parse_OWS(header, pos) + + # We may have reached the end of the header. + if pos == len(header): + break + + # There must be a delimiter after each element except the last one. + if peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) + else: + raise InvalidHeaderFormat(header_name, "expected comma", header, pos) + + # Remove extra delimiters before the next item. + while peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) + + # We may have reached the end of the header. + if pos == len(header): + break + + # Since we only advance in the header by one character with peek_ahead() + # or with the end position of a regex match, we can't overshoot the end. + assert pos == len(header) + + return items + + +def parse_connection_option( + header: str, pos: int, header_name: str +) -> tuple[ConnectionOption, int]: + """ + Parse a Connection option from ``header`` at the given position. + + Return the protocol value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + item, pos = parse_token(header, pos, header_name) + return cast(ConnectionOption, item), pos + + +def parse_connection(header: str) -> list[ConnectionOption]: + """ + Parse a ``Connection`` header. + + Return a list of HTTP connection options. + + Args + header: value of the ``Connection`` header. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_connection_option, header, 0, "Connection") + + +_protocol_re = re.compile( + r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?" +) + + +def parse_upgrade_protocol( + header: str, pos: int, header_name: str +) -> tuple[UpgradeProtocol, int]: + """ + Parse an Upgrade protocol from ``header`` at the given position. + + Return the protocol value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _protocol_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) + return cast(UpgradeProtocol, match.group()), match.end() + + +def parse_upgrade(header: str) -> list[UpgradeProtocol]: + """ + Parse an ``Upgrade`` header. + + Return a list of HTTP protocols. + + Args: + header: Value of the ``Upgrade`` header. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") + + +def parse_extension_item_param( + header: str, pos: int, header_name: str +) -> tuple[ExtensionParameter, int]: + """ + Parse a single extension parameter from ``header`` at the given position. + + Return a ``(name, value)`` pair and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + # Extract parameter name. + name, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) + # Extract parameter value, if there is one. + value: str | None = None + if peek_ahead(header, pos) == "=": + pos = parse_OWS(header, pos + 1) + if peek_ahead(header, pos) == '"': + pos_before = pos # for proper error reporting below + value, pos = parse_quoted_string(header, pos, header_name) + # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says: + # the value after quoted-string unescaping MUST conform to + # the 'token' ABNF. + if _token_re.fullmatch(value) is None: + raise InvalidHeaderFormat( + header_name, "invalid quoted header content", header, pos_before + ) + else: + value, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) + + return (name, value), pos + + +def parse_extension_item( + header: str, pos: int, header_name: str +) -> tuple[ExtensionHeader, int]: + """ + Parse an extension definition from ``header`` at the given position. + + Return an ``(extension name, parameters)`` pair, where ``parameters`` is a + list of ``(name, value)`` pairs, and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + # Extract extension name. + name, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) + # Extract all parameters. + parameters = [] + while peek_ahead(header, pos) == ";": + pos = parse_OWS(header, pos + 1) + parameter, pos = parse_extension_item_param(header, pos, header_name) + parameters.append(parameter) + return (cast(ExtensionName, name), parameters), pos + + +def parse_extension(header: str) -> list[ExtensionHeader]: + """ + Parse a ``Sec-WebSocket-Extensions`` header. + + Return a list of WebSocket extensions and their parameters in this format:: + + [ + ( + 'extension name', + [ + ('parameter name', 'parameter value'), + .... + ] + ), + ... + ] + + Parameter values are :obj:`None` when no value is provided. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") + + +parse_extension_list = parse_extension # alias for backwards compatibility + + +def build_extension_item( + name: ExtensionName, parameters: list[ExtensionParameter] +) -> str: + """ + Build an extension definition. + + This is the reverse of :func:`parse_extension_item`. + + """ + return "; ".join( + [cast(str, name)] + + [ + # Quoted strings aren't necessary because values are always tokens. + name if value is None else f"{name}={value}" + for name, value in parameters + ] + ) + + +def build_extension(extensions: Sequence[ExtensionHeader]) -> str: + """ + Build a ``Sec-WebSocket-Extensions`` header. + + This is the reverse of :func:`parse_extension`. + + """ + return ", ".join( + build_extension_item(name, parameters) for name, parameters in extensions + ) + + +build_extension_list = build_extension # alias for backwards compatibility + + +def parse_subprotocol_item( + header: str, pos: int, header_name: str +) -> tuple[Subprotocol, int]: + """ + Parse a subprotocol from ``header`` at the given position. + + Return the subprotocol value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + item, pos = parse_token(header, pos, header_name) + return cast(Subprotocol, item), pos + + +def parse_subprotocol(header: str) -> list[Subprotocol]: + """ + Parse a ``Sec-WebSocket-Protocol`` header. + + Return a list of WebSocket subprotocols. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") + + +parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility + + +def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str: + """ + Build a ``Sec-WebSocket-Protocol`` header. + + This is the reverse of :func:`parse_subprotocol`. + + """ + return ", ".join(subprotocols) + + +build_subprotocol_list = build_subprotocol # alias for backwards compatibility + + +def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None: + """ + Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`. + + """ + if not isinstance(subprotocols, Sequence): + raise TypeError("subprotocols must be a list") + if isinstance(subprotocols, str): + raise TypeError("subprotocols must be a list, not a str") + for subprotocol in subprotocols: + if not _token_re.fullmatch(subprotocol): + raise ValueError(f"invalid subprotocol: {subprotocol}") + + +def build_www_authenticate_basic(realm: str) -> str: + """ + Build a ``WWW-Authenticate`` header for HTTP Basic Auth. + + Args: + realm: Identifier of the protection space. + + """ + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + realm = build_quoted_string(realm) + charset = build_quoted_string("UTF-8") + return f"Basic realm={realm}, charset={charset}" + + +_token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") + + +def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: + """ + Parse a token68 from ``header`` at the given position. + + Return the token value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _token68_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected token68", header, pos) + return match.group(), match.end() + + +def parse_end(header: str, pos: int, header_name: str) -> None: + """ + Check that parsing reached the end of header. + + """ + if pos < len(header): + raise InvalidHeaderFormat(header_name, "trailing data", header, pos) + + +def parse_authorization_basic(header: str) -> tuple[str, str]: + """ + Parse an ``Authorization`` header for HTTP Basic Auth. + + Return a ``(username, password)`` tuple. + + Args: + header: Value of the ``Authorization`` header. + + Raises: + InvalidHeaderFormat: On invalid inputs. + InvalidHeaderValue: On unsupported inputs. + + """ + # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + scheme, pos = parse_token(header, 0, "Authorization") + if scheme.lower() != "basic": + raise InvalidHeaderValue( + "Authorization", + f"unsupported scheme: {scheme}", + ) + if peek_ahead(header, pos) != " ": + raise InvalidHeaderFormat( + "Authorization", "expected space after scheme", header, pos + ) + pos += 1 + basic_credentials, pos = parse_token68(header, pos, "Authorization") + parse_end(header, pos, "Authorization") + + try: + user_pass = base64.b64decode(basic_credentials.encode()).decode() + except binascii.Error: + raise InvalidHeaderValue( + "Authorization", + "expected base64-encoded credentials", + ) from None + try: + username, password = user_pass.split(":", 1) + except ValueError: + raise InvalidHeaderValue( + "Authorization", + "expected username:password credentials", + ) from None + + return username, password + + +def build_authorization_basic(username: str, password: str) -> str: + """ + Build an ``Authorization`` header for HTTP Basic Auth. + + This is the reverse of :func:`parse_authorization_basic`. + + """ + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + assert ":" not in username + user_pass = f"{username}:{password}" + basic_credentials = base64.b64encode(user_pass.encode()).decode() + return "Basic " + basic_credentials diff --git a/.venv/lib/python3.11/site-packages/websockets/http.py b/.venv/lib/python3.11/site-packages/websockets/http.py new file mode 100644 index 0000000000000000000000000000000000000000..0d860e5379404c12f8fb4177ca4fcb6764b86f3b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/http.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import warnings + +from .datastructures import Headers, MultipleValuesError # noqa: F401 + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.http import read_request, read_response # noqa: F401 + + +warnings.warn( # deprecated in 9.0 - 2021-09-01 + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + DeprecationWarning, +) diff --git a/.venv/lib/python3.11/site-packages/websockets/http11.py b/.venv/lib/python3.11/site-packages/websockets/http11.py new file mode 100644 index 0000000000000000000000000000000000000000..49d7b9a41dbf438b26c22b4a379117d92a1fd1ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/http11.py @@ -0,0 +1,423 @@ +from __future__ import annotations + +import dataclasses +import os +import re +import sys +import warnings +from collections.abc import Generator +from typing import Callable + +from .datastructures import Headers +from .exceptions import SecurityError +from .version import version as websockets_version + + +__all__ = [ + "SERVER", + "USER_AGENT", + "Request", + "Response", +] + + +PYTHON_VERSION = "{}.{}".format(*sys.version_info) + +# User-Agent header for HTTP requests. +USER_AGENT = os.environ.get( + "WEBSOCKETS_USER_AGENT", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + +# Server header for HTTP responses. +SERVER = os.environ.get( + "WEBSOCKETS_SERVER", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + +# Maximum total size of headers is around 128 * 8 KiB = 1 MiB. +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) + +# Limit request line and header lines. 8KiB is the most common default +# configuration of popular HTTP servers. +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) + +# Support for HTTP response bodies is intended to read an error message +# returned by a server. It isn't designed to perform large file transfers. +MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB + + +def d(value: bytes) -> str: + """ + Decode a bytestring for interpolating into an error message. + + """ + return value.decode(errors="backslashreplace") + + +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. + +# Regex for validating header names. + +_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") + + +@dataclasses.dataclass +class Request: + """ + WebSocket handshake request. + + Attributes: + path: Request path, including optional query. + headers: Request headers. + """ + + path: str + headers: Headers + # body isn't useful is the context of this library. + + _exception: Exception | None = None + + @property + def exception(self) -> Exception | None: # pragma: no cover + warnings.warn( # deprecated in 10.3 - 2022-04-17 + "Request.exception is deprecated; use ServerProtocol.handshake_exc instead", + DeprecationWarning, + ) + return self._exception + + @classmethod + def parse( + cls, + read_line: Callable[[int], Generator[None, None, bytes]], + ) -> Generator[None, None, Request]: + """ + Parse a WebSocket handshake request. + + This is a generator-based coroutine. + + The request path isn't URL-decoded or validated in any way. + + The request path and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. + + :meth:`parse` doesn't attempt to read the request body because + WebSocket handshake requests don't have one. If the request contains a + body, it may be read from the data stream after :meth:`parse` returns. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + + Raises: + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + try: + request_line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP request line") from exc + + try: + method, raw_path, protocol = request_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(request_line)}" + ) + if method != b"GET": + raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}") + path = raw_path.decode("ascii", "surrogateescape") + + headers = yield from parse_headers(read_line) + + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 + + if "Transfer-Encoding" in headers: + raise NotImplementedError("transfer codings aren't supported") + + if "Content-Length" in headers: + raise ValueError("unsupported request body") + + return cls(path, headers) + + def serialize(self) -> bytes: + """ + Serialize a WebSocket handshake request. + + """ + # Since the request line and headers only contain ASCII characters, + # we can keep this simple. + request = f"GET {self.path} HTTP/1.1\r\n".encode() + request += self.headers.serialize() + return request + + +@dataclasses.dataclass +class Response: + """ + WebSocket handshake response. + + Attributes: + status_code: Response code. + reason_phrase: Response reason. + headers: Response headers. + body: Response body. + + """ + + status_code: int + reason_phrase: str + headers: Headers + body: bytes = b"" + + _exception: Exception | None = None + + @property + def exception(self) -> Exception | None: # pragma: no cover + warnings.warn( # deprecated in 10.3 - 2022-04-17 + "Response.exception is deprecated; " + "use ClientProtocol.handshake_exc instead", + DeprecationWarning, + ) + return self._exception + + @classmethod + def parse( + cls, + read_line: Callable[[int], Generator[None, None, bytes]], + read_exact: Callable[[int], Generator[None, None, bytes]], + read_to_eof: Callable[[int], Generator[None, None, bytes]], + ) -> Generator[None, None, Response]: + """ + Parse a WebSocket handshake response. + + This is a generator-based coroutine. + + The reason phrase and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data. + read_exact: Generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + read_to_eof: Generator-based coroutine that reads until the end + of the stream. + + Raises: + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + LookupError: If the response isn't well formatted. + ValueError: If the response isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 + + try: + status_line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP status line") from exc + + try: + protocol, raw_status_code, raw_reason = status_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(status_line)}" + ) + try: + status_code = int(raw_status_code) + except ValueError: # invalid literal for int() with base 10 + raise ValueError( + f"invalid status code; expected integer; got {d(raw_status_code)}" + ) from None + if not 100 <= status_code < 600: + raise ValueError( + f"invalid status code; expected 100–599; got {d(raw_status_code)}" + ) + if not _value_re.fullmatch(raw_reason): + raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") + reason = raw_reason.decode("ascii", "surrogateescape") + + headers = yield from parse_headers(read_line) + + body = yield from read_body( + status_code, headers, read_line, read_exact, read_to_eof + ) + + return cls(status_code, reason, headers, body) + + def serialize(self) -> bytes: + """ + Serialize a WebSocket handshake response. + + """ + # Since the status line and headers only contain ASCII characters, + # we can keep this simple. + response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode() + response += self.headers.serialize() + response += self.body + return response + + +def parse_line( + read_line: Callable[[int], Generator[None, None, bytes]], +) -> Generator[None, None, bytes]: + """ + Parse a single line. + + CRLF is stripped from the return value. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: If the connection is closed without a CRLF. + SecurityError: If the response exceeds a security limit. + + """ + try: + line = yield from read_line(MAX_LINE_LENGTH) + except RuntimeError: + raise SecurityError("line too long") + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] + + +def parse_headers( + read_line: Callable[[int], Generator[None, None, bytes]], +) -> Generator[None, None, Headers]: + """ + Parse HTTP headers. + + Non-ASCII characters are represented with surrogate escapes. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: If the connection is closed without complete headers. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 + + # We don't attempt to support obsolete line folding. + + headers = Headers() + for _ in range(MAX_NUM_HEADERS + 1): + try: + line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP headers") from exc + if line == b"": + break + + try: + raw_name, raw_value = line.split(b":", 1) + except ValueError: # not enough values to unpack (expected 2, got 1) + raise ValueError(f"invalid HTTP header line: {d(line)}") from None + if not _token_re.fullmatch(raw_name): + raise ValueError(f"invalid HTTP header name: {d(raw_name)}") + raw_value = raw_value.strip(b" \t") + if not _value_re.fullmatch(raw_value): + raise ValueError(f"invalid HTTP header value: {d(raw_value)}") + + name = raw_name.decode("ascii") # guaranteed to be ASCII at this point + value = raw_value.decode("ascii", "surrogateescape") + headers[name] = value + + else: + raise SecurityError("too many HTTP headers") + + return headers + + +def read_body( + status_code: int, + headers: Headers, + read_line: Callable[[int], Generator[None, None, bytes]], + read_exact: Callable[[int], Generator[None, None, bytes]], + read_to_eof: Callable[[int], Generator[None, None, bytes]], +) -> Generator[None, None, bytes]: + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 + + # Since websockets only does GET requests (no HEAD, no CONNECT), all + # responses except 1xx, 204, and 304 include a message body. + if 100 <= status_code < 200 or status_code == 204 or status_code == 304: + return b"" + + # MultipleValuesError is sufficiently unlikely that we don't attempt to + # handle it when accessing headers. Instead we document that its parent + # class, LookupError, may be raised. + # Conversions from str to int are protected by sys.set_int_max_str_digits.. + + elif (coding := headers.get("Transfer-Encoding")) is not None: + if coding != "chunked": + raise NotImplementedError(f"transfer coding {coding} isn't supported") + + body = b"" + while True: + chunk_size_line = yield from parse_line(read_line) + raw_chunk_size = chunk_size_line.split(b";", 1)[0] + # Set a lower limit than default_max_str_digits; 1 EB is plenty. + if len(raw_chunk_size) > 15: + str_chunk_size = raw_chunk_size.decode(errors="backslashreplace") + raise SecurityError(f"chunk too large: 0x{str_chunk_size} bytes") + chunk_size = int(raw_chunk_size, 16) + if chunk_size == 0: + break + if len(body) + chunk_size > MAX_BODY_SIZE: + raise SecurityError( + f"chunk too large: {chunk_size} bytes after {len(body)} bytes" + ) + body += yield from read_exact(chunk_size) + if (yield from read_exact(2)) != b"\r\n": + raise ValueError("chunk without CRLF") + # Read the trailer. + yield from parse_headers(read_line) + return body + + elif (raw_content_length := headers.get("Content-Length")) is not None: + # Set a lower limit than default_max_str_digits; 1 EiB is plenty. + if len(raw_content_length) > 18: + raise SecurityError(f"body too large: {raw_content_length} bytes") + content_length = int(raw_content_length) + if content_length > MAX_BODY_SIZE: + raise SecurityError(f"body too large: {content_length} bytes") + return (yield from read_exact(content_length)) + + else: + try: + return (yield from read_to_eof(MAX_BODY_SIZE)) + except RuntimeError: + raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") diff --git a/.venv/lib/python3.11/site-packages/websockets/imports.py b/.venv/lib/python3.11/site-packages/websockets/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..c63fb212ec602ae6ec75fe1b86a29fb2e11334df --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/imports.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import warnings +from collections.abc import Iterable +from typing import Any + + +__all__ = ["lazy_import"] + + +def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any: + """ + Import ``name`` from ``source`` in ``namespace``. + + There are two use cases: + + - ``name`` is an object defined in ``source``; + - ``name`` is a submodule of ``source``. + + Neither :func:`__import__` nor :func:`~importlib.import_module` does + exactly this. :func:`__import__` is closer to the intended behavior. + + """ + level = 0 + while source[level] == ".": + level += 1 + assert level < len(source), "importing from parent isn't supported" + module = __import__(source[level:], namespace, None, [name], level) + return getattr(module, name) + + +def lazy_import( + namespace: dict[str, Any], + aliases: dict[str, str] | None = None, + deprecated_aliases: dict[str, str] | None = None, +) -> None: + """ + Provide lazy, module-level imports. + + Typical use:: + + __getattr__, __dir__ = lazy_import( + globals(), + aliases={ + "": "", + ... + }, + deprecated_aliases={ + ..., + } + ) + + This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`. + + """ + if aliases is None: + aliases = {} + if deprecated_aliases is None: + deprecated_aliases = {} + + namespace_set = set(namespace) + aliases_set = set(aliases) + deprecated_aliases_set = set(deprecated_aliases) + + assert not namespace_set & aliases_set, "namespace conflict" + assert not namespace_set & deprecated_aliases_set, "namespace conflict" + assert not aliases_set & deprecated_aliases_set, "namespace conflict" + + package = namespace["__name__"] + + def __getattr__(name: str) -> Any: + assert aliases is not None # mypy cannot figure this out + try: + source = aliases[name] + except KeyError: + pass + else: + return import_name(name, source, namespace) + + assert deprecated_aliases is not None # mypy cannot figure this out + try: + source = deprecated_aliases[name] + except KeyError: + pass + else: + warnings.warn( + f"{package}.{name} is deprecated", + DeprecationWarning, + stacklevel=2, + ) + return import_name(name, source, namespace) + + raise AttributeError(f"module {package!r} has no attribute {name!r}") + + namespace["__getattr__"] = __getattr__ + + def __dir__() -> Iterable[str]: + return sorted(namespace_set | aliases_set | deprecated_aliases_set) + + namespace["__dir__"] = __dir__ diff --git a/.venv/lib/python3.11/site-packages/websockets/protocol.py b/.venv/lib/python3.11/site-packages/websockets/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..bc64a216ad1beb045eb552eb0c7bbee186122f74 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/protocol.py @@ -0,0 +1,758 @@ +from __future__ import annotations + +import enum +import logging +import uuid +from collections.abc import Generator +from typing import Union + +from .exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from .extensions import Extension +from .frames import ( + OK_CLOSE_CODES, + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + CloseCode, + Frame, +) +from .http11 import Request, Response +from .streams import StreamReader +from .typing import LoggerLike, Origin, Subprotocol + + +__all__ = [ + "Protocol", + "Side", + "State", + "SEND_EOF", +] + +# Change to Request | Response | Frame when dropping Python < 3.10. +Event = Union[Request, Response, Frame] +"""Events that :meth:`~Protocol.events_received` may return.""" + + +class Side(enum.IntEnum): + """A WebSocket connection is either a server or a client.""" + + SERVER, CLIENT = range(2) + + +SERVER = Side.SERVER +CLIENT = Side.CLIENT + + +class State(enum.IntEnum): + """A WebSocket connection is in one of these four states.""" + + CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +CONNECTING = State.CONNECTING +OPEN = State.OPEN +CLOSING = State.CLOSING +CLOSED = State.CLOSED + + +SEND_EOF = b"" +"""Sentinel signaling that the TCP connection must be half-closed.""" + + +class Protocol: + """ + Sans-I/O implementation of a WebSocket connection. + + Args: + side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; + :obj:`None` disables the limit. + logger: Logger for this connection; depending on ``side``, + defaults to ``logging.getLogger("websockets.client")`` + or ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + side: Side, + *, + state: State = OPEN, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, + ) -> None: + # Unique identifier. For logs. + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" + + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger: LoggerLike = logger + """Logger for this connection.""" + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + + # Connection side. CLIENT or SERVER. + self.side = side + + # Connection state. Initially OPEN because subclasses handle CONNECTING. + self.state = state + + # Maximum size of incoming messages in bytes. + self.max_size = max_size + + # Current size of incoming message in bytes. Only set while reading a + # fragmented message i.e. a data frames with the FIN bit not set. + self.cur_size: int | None = None + + # True while sending a fragmented message i.e. a data frames with the + # FIN bit not set. + self.expect_continuation_frame = False + + # WebSocket protocol parameters. + self.origin: Origin | None = None + self.extensions: list[Extension] = [] + self.subprotocol: Subprotocol | None = None + + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None + + # Track if an exception happened during the handshake. + self.handshake_exc: Exception | None = None + """ + Exception to raise if the opening handshake failed. + + :obj:`None` if the opening handshake succeeded. + + """ + + # Track if send_eof() was called. + self.eof_sent = False + + # Parser state. + self.reader = StreamReader() + self.events: list[Event] = [] + self.writes: list[bytes] = [] + self.parser = self.parse() + next(self.parser) # start coroutine + self.parser_exc: Exception | None = None + + @property + def state(self) -> State: + """ + State of the WebSocket connection. + + Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`. + + .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 + .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2 + .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3 + .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 + + """ + return self._state + + @state.setter + def state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self._state = state + + @property + def close_code(self) -> int | None: + """ + WebSocket close code received from the remote endpoint. + + Defined in 7.1.5_ of :rfc:`6455`. + + .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return CloseCode.ABNORMAL_CLOSURE + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> str | None: + """ + WebSocket close reason received from the remote endpoint. + + Defined in 7.1.6_ of :rfc:`6455`. + + .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + + @property + def close_exc(self) -> ConnectionClosed: + """ + Exception to raise when trying to interact with a closed connection. + + Don't raise this exception while the connection :attr:`state` + is :attr:`~websockets.protocol.State.CLOSING`; wait until + it's :attr:`~websockets.protocol.State.CLOSED`. + + Indeed, the exception includes the close code and reason, which are + known only once the connection is closed. + + Raises: + AssertionError: If the connection isn't closed yet. + + """ + assert self.state is CLOSED, "connection isn't closed yet" + exc_type: type[ConnectionClosed] + if ( + self.close_rcvd is not None + and self.close_sent is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent.code in OK_CLOSE_CODES + ): + exc_type = ConnectionClosedOK + else: + exc_type = ConnectionClosedError + exc: ConnectionClosed = exc_type( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception raised in the parser, if any. + exc.__cause__ = self.parser_exc + return exc + + # Public methods for receiving data. + + def receive_data(self, data: bytes) -> None: + """ + Receive data from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network. + - You should call :meth:`events_received` and process resulting events. + + Raises: + EOFError: If :meth:`receive_eof` was called earlier. + + """ + self.reader.feed_data(data) + next(self.parser) + + def receive_eof(self) -> None: + """ + Receive the end of the data stream from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network; + it will return ``[b""]``, signaling the end of the stream, or ``[]``. + - You aren't expected to call :meth:`events_received`; it won't return + any new events. + + :meth:`receive_eof` is idempotent. + + """ + if self.reader.eof: + return + self.reader.feed_eof() + next(self.parser) + + # Public methods for sending events. + + def send_continuation(self, data: bytes, fin: bool) -> None: + """ + Send a `Continuation frame`_. + + .. _Continuation frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing the same kind of data + as the initial frame. + fin: FIN bit; set it to :obj:`True` if this is the last frame + of a fragmented message and to :obj:`False` otherwise. + + Raises: + ProtocolError: If a fragmented message isn't in progress. + + """ + if not self.expect_continuation_frame: + raise ProtocolError("unexpected continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_CONT, data, fin)) + + def send_text(self, data: bytes, fin: bool = True) -> None: + """ + Send a `Text frame`_. + + .. _Text frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing text encoded with UTF-8. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: If a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_TEXT, data, fin)) + + def send_binary(self, data: bytes, fin: bool = True) -> None: + """ + Send a `Binary frame`_. + + .. _Binary frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing arbitrary binary data. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: If a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_BINARY, data, fin)) + + def send_close(self, code: int | None = None, reason: str = "") -> None: + """ + Send a `Close frame`_. + + .. _Close frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + + Parameters: + code: close code. + reason: close reason. + + Raises: + ProtocolError: If the code isn't valid or if a reason is provided + without a code. + + """ + # While RFC 6455 doesn't rule out sending more than one close Frame, + # websockets is conservative in what it sends and doesn't allow that. + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + if code is None: + if reason != "": + raise ProtocolError("cannot send a reason without a code") + close = Close(CloseCode.NO_STATUS_RCVD, "") + data = b"" + else: + close = Close(code, reason) + data = close.serialize() + # 7.1.3. The WebSocket Closing Handshake is Started + self.send_frame(Frame(OP_CLOSE, data)) + # Since the state is OPEN, no close frame was received yet. + # As a consequence, self.close_rcvd_then_sent remains None. + assert self.close_rcvd is None + self.close_sent = close + self.state = CLOSING + + def send_ping(self, data: bytes) -> None: + """ + Send a `Ping frame`_. + + .. _Ping frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + Parameters: + data: payload containing arbitrary binary data. + + """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.send_frame(Frame(OP_PING, data)) + + def send_pong(self, data: bytes) -> None: + """ + Send a `Pong frame`_. + + .. _Pong frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + Parameters: + data: payload containing arbitrary binary data. + + """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.send_frame(Frame(OP_PONG, data)) + + def fail(self, code: int, reason: str = "") -> None: + """ + `Fail the WebSocket connection`_. + + .. _Fail the WebSocket connection: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 + + Parameters: + code: close code + reason: close reason + + Raises: + ProtocolError: If the code isn't valid. + """ + # 7.1.7. Fail the WebSocket Connection + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because + # of an error reading from or writing to the network. + if self.state is OPEN: + if code != CloseCode.ABNORMAL_CLOSURE: + close = Close(code, reason) + data = close.serialize() + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + # If recv_messages() raised an exception upon receiving a close + # frame but before echoing it, then close_rcvd is not None even + # though the state is OPEN. This happens when the connection is + # closed while receiving a fragmented message. + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True + self.state = CLOSING + + # When failing the connection, a server closes the TCP connection + # without waiting for the client to complete the handshake, while a + # client waits for the server to close the TCP connection, possibly + # after sending a close frame that the client will ignore. + if self.side is SERVER and not self.eof_sent: + self.send_eof() + + # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue + # to attempt to process data(including a responding Close frame) from + # the remote endpoint after being instructed to _Fail the WebSocket + # Connection_." + self.parser = self.discard() + next(self.parser) # start coroutine + + # Public method for getting incoming events after receiving data. + + def events_received(self) -> list[Event]: + """ + Fetch events generated from data received from the network. + + Call this method immediately after any of the ``receive_*()`` methods. + + Process resulting events, likely by passing them to the application. + + Returns: + Events read from the connection. + """ + events, self.events = self.events, [] + return events + + # Public method for getting outgoing data after receiving data or sending events. + + def data_to_send(self) -> list[bytes]: + """ + Obtain data to send to the network. + + Call this method immediately after any of the ``receive_*()``, + ``send_*()``, or :meth:`fail` methods. + + Write resulting data to the connection. + + The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals + the end of the data stream. When you receive it, half-close the TCP + connection. + + Returns: + Data to write to the connection. + + """ + writes, self.writes = self.writes, [] + return writes + + def close_expected(self) -> bool: + """ + Tell if the TCP connection is expected to close soon. + + Call this method immediately after any of the ``receive_*()``, + ``send_close()``, or :meth:`fail` methods. + + If it returns :obj:`True`, schedule closing the TCP connection after a + short timeout if the other side hasn't already closed it. + + Returns: + Whether the TCP connection is expected to close soon. + + """ + # During the opening handshake, when our state is CONNECTING, we expect + # a TCP close if and only if the hansdake fails. When it does, we start + # the TCP closing handshake by sending EOF with send_eof(). + + # Once the opening handshake completes successfully, we expect a TCP + # close if and only if we sent a close frame, meaning that our state + # progressed to CLOSING: + + # * Normal closure: once we send a close frame, we expect a TCP close: + # server waits for client to complete the TCP closing handshake; + # client waits for server to initiate the TCP closing handshake. + + # * Abnormal closure: we always send a close frame and the same logic + # applies, except on EOFError where we don't send a close frame + # because we already received the TCP close, so we don't expect it. + + # If our state is CLOSED, we already received a TCP close so we don't + # expect it anymore. + + # Micro-optimization: put the most common case first + if self.state is OPEN: + return False + if self.state is CLOSING: + return True + if self.state is CLOSED: + return False + assert self.state is CONNECTING + return self.eof_sent + + # Private methods for receiving data. + + def parse(self) -> Generator[None]: + """ + Parse incoming data into frames. + + :meth:`receive_data` and :meth:`receive_eof` run this generator + coroutine until it needs more data or reaches EOF. + + :meth:`parse` never raises an exception. Instead, it sets the + :attr:`parser_exc` and yields control. + + """ + try: + while True: + if (yield from self.reader.at_eof()): + if self.debug: + self.logger.debug("< EOF") + # If the WebSocket connection is closed cleanly, with a + # closing handhshake, recv_frame() substitutes parse() + # with discard(). This branch is reached only when the + # connection isn't closed cleanly. + raise EOFError("unexpected end of stream") + + if self.max_size is None: + max_size = None + elif self.cur_size is None: + max_size = self.max_size + else: + max_size = self.max_size - self.cur_size + + # During a normal closure, execution ends here on the next + # iteration of the loop after receiving a close frame. At + # this point, recv_frame() replaced parse() by discard(). + frame = yield from Frame.parse( + self.reader.read_exact, + mask=self.side is SERVER, + max_size=max_size, + extensions=self.extensions, + ) + + if self.debug: + self.logger.debug("< %s", frame) + + self.recv_frame(frame) + + except ProtocolError as exc: + self.fail(CloseCode.PROTOCOL_ERROR, str(exc)) + self.parser_exc = exc + + except EOFError as exc: + self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc)) + self.parser_exc = exc + + except UnicodeDecodeError as exc: + self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}") + self.parser_exc = exc + + except PayloadTooBig as exc: + exc.set_current_size(self.cur_size) + self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) + self.parser_exc = exc + + except Exception as exc: + self.logger.error("parser failed", exc_info=True) + # Don't include exception details, which may be security-sensitive. + self.fail(CloseCode.INTERNAL_ERROR) + self.parser_exc = exc + + # During an abnormal closure, execution ends here after catching an + # exception. At this point, fail() replaced parse() by discard(). + yield + raise AssertionError("parse() shouldn't step after error") + + def discard(self) -> Generator[None]: + """ + Discard incoming data. + + This coroutine replaces :meth:`parse`: + + - after receiving a close frame, during a normal closure (1.4); + - after sending a close frame, during an abnormal closure (7.1.7). + + """ + # After the opening handshake completes, the server closes the TCP + # connection in the same circumstances where discard() replaces parse(). + # The client closes it when it receives EOF from the server or times + # out. (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent) + while not (yield from self.reader.at_eof()): + self.reader.discard() + if self.debug: + self.logger.debug("< EOF") + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is CLIENT and self.state is not CONNECTING: + self.send_eof() + self.state = CLOSED + # If discard() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() methods raise an + # error, so our receive_data/eof() methods don't step the generator. + raise AssertionError("discard() shouldn't step after EOF") + + def recv_frame(self, frame: Frame) -> None: + """ + Process an incoming frame. + + """ + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: + if self.cur_size is not None: + raise ProtocolError("expected a continuation frame") + if not frame.fin: + self.cur_size = len(frame.data) + + elif frame.opcode is OP_CONT: + if self.cur_size is None: + raise ProtocolError("unexpected continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size += len(frame.data) + + elif frame.opcode is OP_PING: + # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST + # send a Pong frame in response" + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) + + elif frame.opcode is OP_PONG: + # 5.5.3 Pong: "A response to an unsolicited Pong frame is not + # expected." + pass + + elif frame.opcode is OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False + + if self.cur_size is not None: + raise ProtocolError("incomplete fragmented message") + + # 5.5.1 Close: "If an endpoint receives a Close frame and did + # not previously send a Close frame, the endpoint MUST send a + # Close frame in response. (When sending a Close frame in + # response, the endpoint typically echos the status code it + # received.)" + + if self.state is OPEN: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthesizes a 1005 close code. + # The rest is identical to send_close(). + self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True + self.state = CLOSING + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is SERVER: + self.send_eof() + + # 1.4. Closing Handshake: "after receiving a control frame + # indicating the connection should be closed, a peer discards + # any further data received." + # RFC 6455 allows reading Ping and Pong frames after a Close frame. + # However, that doesn't seem useful; websockets doesn't support it. + self.parser = self.discard() + next(self.parser) # start coroutine + + else: + # This can't happen because Frame.parse() validates opcodes. + raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") + + self.events.append(frame) + + # Private methods for sending events. + + def send_frame(self, frame: Frame) -> None: + if self.debug: + self.logger.debug("> %s", frame) + self.writes.append( + frame.serialize( + mask=self.side is CLIENT, + extensions=self.extensions, + ) + ) + + def send_eof(self) -> None: + assert not self.eof_sent + self.eof_sent = True + if self.debug: + self.logger.debug("> EOF") + self.writes.append(SEND_EOF) diff --git a/.venv/lib/python3.11/site-packages/websockets/py.typed b/.venv/lib/python3.11/site-packages/websockets/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/websockets/server.py b/.venv/lib/python3.11/site-packages/websockets/server.py new file mode 100644 index 0000000000000000000000000000000000000000..90e6c9921c35178e62b6ce6ae7962ae1ed1d22df --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/server.py @@ -0,0 +1,604 @@ +from __future__ import annotations + +import base64 +import binascii +import email.utils +import http +import re +import warnings +from collections.abc import Generator, Sequence +from typing import Any, Callable, cast + +from .datastructures import Headers, MultipleValuesError +from .exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidMessage, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) +from .extensions import Extension, ServerExtensionFactory +from .headers import ( + build_extension, + parse_connection, + parse_extension, + parse_subprotocol, + parse_upgrade, +) +from .http11 import Request, Response +from .imports import lazy_import +from .protocol import CONNECTING, OPEN, SERVER, Protocol, State +from .typing import ( + ConnectionOption, + ExtensionHeader, + LoggerLike, + Origin, + StatusLike, + Subprotocol, + UpgradeProtocol, +) +from .utils import accept_key + + +__all__ = ["ServerProtocol"] + + +class ServerProtocol(Protocol): + """ + Sans-I/O implementation of a WebSocket server connection. + + Args: + origins: Acceptable values of the ``Origin`` header. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket + Hijacking attacks. + extensions: List of supported extensions, in order in which they + should be tried. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It has the same + signature as the :meth:`select_subprotocol` method, including a + :class:`ServerProtocol` instance as first argument. + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; + :obj:`None` disables the limit. + logger: Logger for this connection; + defaults to ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + *, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + state: State = CONNECTING, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, + ) -> None: + super().__init__( + side=SERVER, + state=state, + max_size=max_size, + logger=logger, + ) + self.origins = origins + self.available_extensions = extensions + self.available_subprotocols = subprotocols + if select_subprotocol is not None: + # Bind select_subprotocol then shadow self.select_subprotocol. + # Use setattr to work around https://github.com/python/mypy/issues/2427. + setattr( + self, + "select_subprotocol", + select_subprotocol.__get__(self, self.__class__), + ) + + def accept(self, request: Request) -> Response: + """ + Create a handshake response to accept the connection. + + If the handshake request is valid and the handshake successful, + :meth:`accept` returns an HTTP response with status code 101. + + Else, it returns an HTTP response with another status code. This rejects + the connection, like :meth:`reject` would. + + You must send the handshake response with :meth:`send_response`. + + You may modify the response before sending it, typically by adding HTTP + headers. + + Args: + request: WebSocket handshake request received from the client. + + Returns: + WebSocket handshake response or HTTP response to send to the client. + + """ + try: + ( + accept_header, + extensions_header, + protocol_header, + ) = self.process_request(request) + except InvalidOrigin as exc: + request._exception = exc + self.handshake_exc = exc + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) + return self.reject( + http.HTTPStatus.FORBIDDEN, + f"Failed to open a WebSocket connection: {exc}.\n", + ) + except InvalidUpgrade as exc: + request._exception = exc + self.handshake_exc = exc + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) + response = self.reject( + http.HTTPStatus.UPGRADE_REQUIRED, + ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. You need a WebSocket client.\n" + ), + ) + response.headers["Upgrade"] = "websocket" + return response + except InvalidHandshake as exc: + request._exception = exc + self.handshake_exc = exc + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" + return self.reject( + http.HTTPStatus.BAD_REQUEST, + f"Failed to open a WebSocket connection: {exc_str}.\n", + ) + except Exception as exc: + # Handle exceptions raised by user-provided select_subprotocol and + # unexpected errors. + request._exception = exc + self.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + return self.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + headers = Headers() + + headers["Date"] = email.utils.formatdate(usegmt=True) + + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Accept"] = accept_header + + if extensions_header is not None: + headers["Sec-WebSocket-Extensions"] = extensions_header + + if protocol_header is not None: + headers["Sec-WebSocket-Protocol"] = protocol_header + + return Response(101, "Switching Protocols", headers) + + def process_request( + self, + request: Request, + ) -> tuple[str, str | None, str | None]: + """ + Check a handshake request and negotiate extensions and subprotocol. + + This function doesn't verify that the request is an HTTP/1.1 or higher + GET request and doesn't check the ``Host`` header. These controls are + usually performed earlier in the HTTP request handling code. They're + the responsibility of the caller. + + Args: + request: WebSocket handshake request received from the client. + + Returns: + ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and + ``Sec-WebSocket-Protocol`` headers for the handshake response. + + Raises: + InvalidHandshake: If the handshake request is invalid; + then the server must return 400 Bad Request error. + + """ + headers = request.headers + + connection: list[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade( + "Connection", ", ".join(connection) if connection else None + ) + + upgrade: list[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) + + try: + key = headers["Sec-WebSocket-Key"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Key") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None + + try: + raw_key = base64.b64decode(key.encode(), validate=True) + except binascii.Error as exc: + raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc + if len(raw_key) != 16: + raise InvalidHeaderValue("Sec-WebSocket-Key", key) + + try: + version = headers["Sec-WebSocket-Version"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Version") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None + + if version != "13": + raise InvalidHeaderValue("Sec-WebSocket-Version", version) + + accept_header = accept_key(key) + + self.origin = self.process_origin(headers) + + extensions_header, self.extensions = self.process_extensions(headers) + + protocol_header = self.subprotocol = self.process_subprotocol(headers) + + return ( + accept_header, + extensions_header, + protocol_header, + ) + + def process_origin(self, headers: Headers) -> Origin | None: + """ + Handle the Origin HTTP request header. + + Args: + headers: WebSocket handshake request headers. + + Returns: + origin, if it is acceptable. + + Raises: + InvalidHandshake: If the Origin header is invalid. + InvalidOrigin: If the origin isn't acceptable. + + """ + # "The user agent MUST NOT include more than one Origin header field" + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. + try: + origin = headers.get("Origin") + except MultipleValuesError: + raise InvalidHeader("Origin", "multiple values") from None + if origin is not None: + origin = cast(Origin, origin) + if self.origins is not None: + for origin_or_regex in self.origins: + if origin_or_regex == origin or ( + isinstance(origin_or_regex, re.Pattern) + and origin is not None + and origin_or_regex.fullmatch(origin) is not None + ): + break + else: + raise InvalidOrigin(origin) + return origin + + def process_extensions( + self, + headers: Headers, + ) -> tuple[str | None, list[Extension]]: + """ + Handle the Sec-WebSocket-Extensions HTTP request header. + + Accept or reject each extension proposed in the client request. + Negotiate parameters for accepted extensions. + + Per :rfc:`6455`, negotiation rules are defined by the specification of + each extension. + + To provide this level of flexibility, for each extension proposed by + the client, we check for a match with each extension available in the + server configuration. If no match is found, the extension is ignored. + + If several variants of the same extension are proposed by the client, + it may be accepted several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + This process doesn't allow the server to reorder extensions. It can + only select a subset of the extensions proposed by the client. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + Args: + headers: WebSocket handshake request headers. + + Returns: + ``Sec-WebSocket-Extensions`` HTTP response header and list of + accepted extensions. + + Raises: + InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. + + """ + response_header_value: str | None = None + + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values and self.available_extensions: + parsed_header_values: list[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, request_params in parsed_header_values: + for ext_factory in self.available_extensions: + # Skip non-matching extensions based on their name. + if ext_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + response_params, extension = ext_factory.process_request_params( + request_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + extension_headers.append((name, response_params)) + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the client sent. The extension is declined. + + # Serialize extension header. + if extension_headers: + response_header_value = build_extension(extension_headers) + + return response_header_value, accepted_extensions + + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: + """ + Handle the Sec-WebSocket-Protocol HTTP request header. + + Args: + headers: WebSocket handshake request headers. + + Returns: + Subprotocol, if one was selected; this is also the value of the + ``Sec-WebSocket-Protocol`` response header. + + Raises: + InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid. + + """ + subprotocols: Sequence[Subprotocol] = sum( + [ + parse_subprotocol(header_value) + for header_value in headers.get_all("Sec-WebSocket-Protocol") + ], + [], + ) + + return self.select_subprotocol(subprotocols) + + def select_subprotocol( + self, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + """ + Pick a subprotocol among those offered by the client. + + If several subprotocols are supported by both the client and the server, + pick the first one in the list declared the server. + + If the server doesn't support any subprotocols, continue without a + subprotocol, regardless of what the client offers. + + If the server supports at least one subprotocol and the client doesn't + offer any, abort the handshake with an HTTP 400 error. + + You provide a ``select_subprotocol`` argument to :class:`ServerProtocol` + to override this logic. For example, you could accept the connection + even if client doesn't offer a subprotocol, rather than reject it. + + Here's how to negotiate the ``chat`` subprotocol if the client supports + it and continue without a subprotocol otherwise:: + + def select_subprotocol(protocol, subprotocols): + if "chat" in subprotocols: + return "chat" + + Args: + subprotocols: List of subprotocols offered by the client. + + Returns: + Selected subprotocol, if a common subprotocol was found. + + :obj:`None` to continue without a subprotocol. + + Raises: + NegotiationError: Custom implementations may raise this exception + to abort the handshake with an HTTP 400 error. + + """ + # Server doesn't offer any subprotocols. + if not self.available_subprotocols: # None or empty list + return None + + # Server offers at least one subprotocol but client doesn't offer any. + if not subprotocols: + raise NegotiationError("missing subprotocol") + + # Server and client both offer subprotocols. Look for a shared one. + proposed_subprotocols = set(subprotocols) + for subprotocol in self.available_subprotocols: + if subprotocol in proposed_subprotocols: + return subprotocol + + # No common subprotocol was found. + raise NegotiationError( + "invalid subprotocol; expected one of " + + ", ".join(self.available_subprotocols) + ) + + def reject(self, status: StatusLike, text: str) -> Response: + """ + Create a handshake response to reject the connection. + + A short plain text response is the best fallback when failing to + establish a WebSocket connection. + + You must send the handshake response with :meth:`send_response`. + + You may modify the response before sending it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + # If status is an int instead of an HTTPStatus, fix it automatically. + status = http.HTTPStatus(status) + body = text.encode() + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "text/plain; charset=utf-8"), + ] + ) + return Response(status.value, status.phrase, headers, body) + + def send_response(self, response: Response) -> None: + """ + Send a handshake response to the client. + + Args: + response: WebSocket handshake response event to send. + + """ + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("> HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if response.body: + self.logger.debug("> [body] (%d bytes)", len(response.body)) + + self.writes.append(response.serialize()) + + if response.status_code == 101: + assert self.state is CONNECTING + self.state = OPEN + self.logger.info("connection open") + + else: + self.logger.info( + "connection rejected (%d %s)", + response.status_code, + response.reason_phrase, + ) + + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + + def parse(self) -> Generator[None]: + if self.state is CONNECTING: + try: + request = yield from Request.parse( + self.reader.read_line, + ) + except Exception as exc: + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP request" + ) + self.handshake_exc.__cause__ = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield + + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.events.append(request) + + yield from super().parse() + + +class ServerConnection(ServerProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( # deprecated in 11.0 - 2023-04-02 + "ServerConnection was renamed to ServerProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + }, +) diff --git a/.venv/lib/python3.11/site-packages/websockets/speedups.c b/.venv/lib/python3.11/site-packages/websockets/speedups.c new file mode 100644 index 0000000000000000000000000000000000000000..cb10dedb83fd160b340cdb72f2d11ba94424a807 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/speedups.c @@ -0,0 +1,222 @@ +/* C implementation of performance sensitive functions. */ + +#define PY_SSIZE_T_CLEAN +#include +#include /* uint8_t, uint32_t, uint64_t */ + +#if __ARM_NEON +#include +#elif __SSE2__ +#include +#endif + +static const Py_ssize_t MASK_LEN = 4; + +/* Similar to PyBytes_AsStringAndSize, but accepts more types */ + +static int +_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length) +{ + // This supports bytes, bytearrays, and memoryview objects, + // which are common data structures for handling byte streams. + // If *tmp isn't NULL, the caller gets a new reference. + if (PyBytes_Check(obj)) + { + *tmp = NULL; + *buffer = PyBytes_AS_STRING(obj); + *length = PyBytes_GET_SIZE(obj); + } + else if (PyByteArray_Check(obj)) + { + *tmp = NULL; + *buffer = PyByteArray_AS_STRING(obj); + *length = PyByteArray_GET_SIZE(obj); + } + else if (PyMemoryView_Check(obj)) + { + *tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C'); + if (*tmp == NULL) + { + return -1; + } + Py_buffer *mv_buf; + mv_buf = PyMemoryView_GET_BUFFER(*tmp); + *buffer = mv_buf->buf; + *length = mv_buf->len; + } + else + { + PyErr_Format( + PyExc_TypeError, + "expected a bytes-like object, %.200s found", + Py_TYPE(obj)->tp_name); + return -1; + } + + return 0; +} + +/* C implementation of websockets.utils.apply_mask */ + +static PyObject * +apply_mask(PyObject *self, PyObject *args, PyObject *kwds) +{ + + // In order to support various bytes-like types, accept any Python object. + + static char *kwlist[] = {"data", "mask", NULL}; + PyObject *input_obj; + PyObject *mask_obj; + + // A pointer to a char * + length will be extracted from the data and mask + // arguments, possibly via a Py_buffer. + + PyObject *input_tmp = NULL; + char *input; + Py_ssize_t input_len; + PyObject *mask_tmp = NULL; + char *mask; + Py_ssize_t mask_len; + + // Initialize a PyBytesObject then get a pointer to the underlying char * + // in order to avoid an extra memory copy in PyBytes_FromStringAndSize. + + PyObject *result = NULL; + char *output; + + // Other variables. + + Py_ssize_t i = 0; + + // Parse inputs. + + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "OO", kwlist, &input_obj, &mask_obj)) + { + goto exit; + } + + if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1) + { + goto exit; + } + + if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1) + { + goto exit; + } + + if (mask_len != MASK_LEN) + { + PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes"); + goto exit; + } + + // Create output. + + result = PyBytes_FromStringAndSize(NULL, input_len); + if (result == NULL) + { + goto exit; + } + + // Since we just created result, we don't need error checks. + output = PyBytes_AS_STRING(result); + + // Perform the masking operation. + + // Apparently GCC cannot figure out the following optimizations by itself. + + // We need a new scope for MSVC 2010 (non C99 friendly) + { +#if __ARM_NEON + + // With NEON support, XOR by blocks of 16 bytes = 128 bits. + + Py_ssize_t input_len_128 = input_len & ~15; + uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask)); + + for (; i < input_len_128; i += 16) + { + uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i)); + uint8x16_t out_128 = veorq_u8(in_128, mask_128); + vst1q_u8((uint8_t *)(output + i), out_128); + } + +#elif __SSE2__ + + // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. + + // Since we cannot control the 16-bytes alignment of input and output + // buffers, we rely on loadu/storeu rather than load/store. + + Py_ssize_t input_len_128 = input_len & ~15; + __m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask); + + for (; i < input_len_128; i += 16) + { + __m128i in_128 = _mm_loadu_si128((__m128i *)(input + i)); + __m128i out_128 = _mm_xor_si128(in_128, mask_128); + _mm_storeu_si128((__m128i *)(output + i), out_128); + } + +#else + + // Without SSE2 support, XOR by blocks of 8 bytes = 64 bits. + + // We assume the memory allocator aligns everything on 8 bytes boundaries. + + Py_ssize_t input_len_64 = input_len & ~7; + uint32_t mask_32 = *(uint32_t *)mask; + uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32; + + for (; i < input_len_64; i += 8) + { + *(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64; + } + +#endif + } + + // XOR the remainder of the input byte by byte. + + for (; i < input_len; i++) + { + output[i] = input[i] ^ mask[i & (MASK_LEN - 1)]; + } + +exit: + Py_XDECREF(input_tmp); + Py_XDECREF(mask_tmp); + return result; + +} + +static PyMethodDef speedups_methods[] = { + { + "apply_mask", + (PyCFunction)apply_mask, + METH_VARARGS | METH_KEYWORDS, + "Apply masking to the data of a WebSocket message.", + }, + {NULL, NULL, 0, NULL}, /* Sentinel */ +}; + +static struct PyModuleDef speedups_module = { + PyModuleDef_HEAD_INIT, + "websocket.speedups", /* m_name */ + "C implementation of performance sensitive functions.", + /* m_doc */ + -1, /* m_size */ + speedups_methods, /* m_methods */ + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC +PyInit_speedups(void) +{ + return PyModule_Create(&speedups_module); +} diff --git a/.venv/lib/python3.11/site-packages/websockets/speedups.cpython-311-x86_64-linux-gnu.so b/.venv/lib/python3.11/site-packages/websockets/speedups.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..7c5720a293c85c51a1bdc560cfda6bd4536d761c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/websockets/speedups.cpython-311-x86_64-linux-gnu.so differ diff --git a/.venv/lib/python3.11/site-packages/websockets/speedups.pyi b/.venv/lib/python3.11/site-packages/websockets/speedups.pyi new file mode 100644 index 0000000000000000000000000000000000000000..821438a064e6ad32154eb6536c975f70d4c35d05 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/speedups.pyi @@ -0,0 +1 @@ +def apply_mask(data: bytes, mask: bytes) -> bytes: ... diff --git a/.venv/lib/python3.11/site-packages/websockets/streams.py b/.venv/lib/python3.11/site-packages/websockets/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..f52e6193aa979564dab68058835ff0ca86b9ca38 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/streams.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from collections.abc import Generator + + +class StreamReader: + """ + Generator-based stream reader. + + This class doesn't support concurrent calls to :meth:`read_line`, + :meth:`read_exact`, or :meth:`read_to_eof`. Make sure calls are + serialized. + + """ + + def __init__(self) -> None: + self.buffer = bytearray() + self.eof = False + + def read_line(self, m: int) -> Generator[None, None, bytes]: + """ + Read a LF-terminated line from the stream. + + This is a generator-based coroutine. + + The return value includes the LF character. + + Args: + m: Maximum number bytes to read; this is a security limit. + + Raises: + EOFError: If the stream ends without a LF. + RuntimeError: If the stream ends in more than ``m`` bytes. + + """ + n = 0 # number of bytes to read + p = 0 # number of bytes without a newline + while True: + n = self.buffer.find(b"\n", p) + 1 + if n > 0: + break + p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") + if self.eof: + raise EOFError(f"stream ends after {p} bytes, before end of line") + yield + if n > m: + raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes") + r = self.buffer[:n] + del self.buffer[:n] + return r + + def read_exact(self, n: int) -> Generator[None, None, bytes]: + """ + Read a given number of bytes from the stream. + + This is a generator-based coroutine. + + Args: + n: How many bytes to read. + + Raises: + EOFError: If the stream ends in less than ``n`` bytes. + + """ + assert n >= 0 + while len(self.buffer) < n: + if self.eof: + p = len(self.buffer) + raise EOFError(f"stream ends after {p} bytes, expected {n} bytes") + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def read_to_eof(self, m: int) -> Generator[None, None, bytes]: + """ + Read all bytes from the stream. + + This is a generator-based coroutine. + + Args: + m: Maximum number bytes to read; this is a security limit. + + Raises: + RuntimeError: If the stream ends in more than ``m`` bytes. + + """ + while not self.eof: + p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") + yield + r = self.buffer[:] + del self.buffer[:] + return r + + def at_eof(self) -> Generator[None, None, bool]: + """ + Tell whether the stream has ended and all data was read. + + This is a generator-based coroutine. + + """ + while True: + if self.buffer: + return False + if self.eof: + return True + # When all data was read but the stream hasn't ended, we can't + # tell if until either feed_data() or feed_eof() is called. + yield + + def feed_data(self, data: bytes) -> None: + """ + Write data to the stream. + + :meth:`feed_data` cannot be called after :meth:`feed_eof`. + + Args: + data: Data to write. + + Raises: + EOFError: If the stream has ended. + + """ + if self.eof: + raise EOFError("stream ended") + self.buffer += data + + def feed_eof(self) -> None: + """ + End the stream. + + :meth:`feed_eof` cannot be called more than once. + + Raises: + EOFError: If the stream has ended. + + """ + if self.eof: + raise EOFError("stream ended") + self.eof = True + + def discard(self) -> None: + """ + Discard all buffered data, but don't end the stream. + + """ + del self.buffer[:] diff --git a/.venv/lib/python3.11/site-packages/websockets/typing.py b/.venv/lib/python3.11/site-packages/websockets/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..0a37141c6cb0c6b0c940bd3b6ea20e0855c08c68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/typing.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import http +import logging +import typing +from typing import Any, NewType, Optional, Union + + +__all__ = [ + "Data", + "LoggerLike", + "StatusLike", + "Origin", + "Subprotocol", + "ExtensionName", + "ExtensionParameter", +] + + +# Public types used in the signature of public APIs + +# Change to str | bytes when dropping Python < 3.10. +Data = Union[str, bytes] +"""Types supported in a WebSocket message: +:class:`str` for a Text_ frame, :class:`bytes` for a Binary_. + +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 +.. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + +""" + + +# Change to logging.Logger | ... when dropping Python < 3.10. +if typing.TYPE_CHECKING: + LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] + """Types accepted where a :class:`~logging.Logger` is expected.""" +else: # remove this branch when dropping support for Python < 3.11 + LoggerLike = Union[logging.Logger, logging.LoggerAdapter] + """Types accepted where a :class:`~logging.Logger` is expected.""" + + +# Change to http.HTTPStatus | int when dropping Python < 3.10. +StatusLike = Union[http.HTTPStatus, int] +""" +Types accepted where an :class:`~http.HTTPStatus` is expected.""" + + +Origin = NewType("Origin", str) +"""Value of a ``Origin`` header.""" + + +Subprotocol = NewType("Subprotocol", str) +"""Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" + + +ExtensionName = NewType("ExtensionName", str) +"""Name of a WebSocket extension.""" + +# Change to tuple[str, str | None] when dropping Python < 3.10. +ExtensionParameter = tuple[str, Optional[str]] +"""Parameter of a WebSocket extension.""" + + +# Private types + +ExtensionHeader = tuple[ExtensionName, list[ExtensionParameter]] +"""Extension in a ``Sec-WebSocket-Extensions`` header.""" + + +ConnectionOption = NewType("ConnectionOption", str) +"""Connection option in a ``Connection`` header.""" + + +UpgradeProtocol = NewType("UpgradeProtocol", str) +"""Upgrade protocol in an ``Upgrade`` header.""" diff --git a/.venv/lib/python3.11/site-packages/websockets/uri.py b/.venv/lib/python3.11/site-packages/websockets/uri.py new file mode 100644 index 0000000000000000000000000000000000000000..16bb3f1c1b206a04e51f087bfaf434b26b5e8efa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/uri.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import dataclasses +import urllib.parse + +from .exceptions import InvalidURI + + +__all__ = ["parse_uri", "WebSocketURI"] + + +@dataclasses.dataclass +class WebSocketURI: + """ + WebSocket URI. + + Attributes: + secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI. + host: Normalized to lower case. + port: Always set even if it's the default. + path: May be empty. + query: May be empty if the URI doesn't include a query component. + username: Available when the URI contains `User Information`_. + password: Available when the URI contains `User Information`_. + + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 + + """ + + secure: bool + host: str + port: int + path: str + query: str + username: str | None = None + password: str | None = None + + @property + def resource_name(self) -> str: + if self.path: + resource_name = self.path + else: + resource_name = "/" + if self.query: + resource_name += "?" + self.query + return resource_name + + @property + def user_info(self) -> tuple[str, str] | None: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) + + +# All characters from the gen-delims and sub-delims sets in RFC 3987. +DELIMS = ":/?#[]@!$&'()*+,;=" + + +def parse_uri(uri: str) -> WebSocketURI: + """ + Parse and validate a WebSocket URI. + + Args: + uri: WebSocket URI. + + Returns: + Parsed WebSocket URI. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + + """ + parsed = urllib.parse.urlparse(uri) + if parsed.scheme not in ["ws", "wss"]: + raise InvalidURI(uri, "scheme isn't ws or wss") + if parsed.hostname is None: + raise InvalidURI(uri, "hostname isn't provided") + if parsed.fragment != "": + raise InvalidURI(uri, "fragment identifier is meaningless") + + secure = parsed.scheme == "wss" + host = parsed.hostname + port = parsed.port or (443 if secure else 80) + path = parsed.path + query = parsed.query + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise InvalidURI(uri, "username provided without password") + + try: + uri.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + path = urllib.parse.quote(path, safe=DELIMS) + query = urllib.parse.quote(query, safe=DELIMS) + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return WebSocketURI(secure, host, port, path, query, username, password) diff --git a/.venv/lib/python3.11/site-packages/websockets/version.py b/.venv/lib/python3.11/site-packages/websockets/version.py new file mode 100644 index 0000000000000000000000000000000000000000..4b11b6fe9e12b6ae9dd9092a445ab2cd91855736 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/websockets/version.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import importlib.metadata + + +__all__ = ["tag", "version", "commit"] + + +# ========= =========== =================== +# release development +# ========= =========== =================== +# tag X.Y X.Y (upcoming) +# version X.Y X.Y.dev1+g5678cde +# commit X.Y 5678cde +# ========= =========== =================== + + +# When tagging a release, set `released = True`. +# After tagging a release, set `released = False` and increment `tag`. + +released = True + +tag = version = commit = "14.2" + + +if not released: # pragma: no cover + import pathlib + import re + import subprocess + + def get_version(tag: str) -> str: + # Since setup.py executes the contents of src/websockets/version.py, + # __file__ can point to either of these two files. + file_path = pathlib.Path(__file__) + root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] + + # Read version from package metadata if it is installed. + try: + version = importlib.metadata.version("websockets") + except ImportError: + pass + else: + # Check that this file belongs to the installed package. + files = importlib.metadata.files("websockets") + if files: + version_files = [f for f in files if f.name == file_path.name] + if version_files: + version_file = version_files[0] + if version_file.locate() == file_path: + return version + + # Read version from git if available. + try: + description = subprocess.run( + ["git", "describe", "--dirty", "--tags", "--long"], + capture_output=True, + cwd=root_dir, + timeout=1, + check=True, + text=True, + ).stdout.strip() + # subprocess.run raises FileNotFoundError if git isn't on $PATH. + except ( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): + pass + else: + description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" + match = re.fullmatch(description_re, description) + if match is None: + raise ValueError(f"Unexpected git description: {description}") + distance, remainder = match.groups() + remainder = remainder.replace("-", ".") # required by PEP 440 + return f"{tag}.dev{distance}+{remainder}" + + # Avoid crashing if the development version cannot be determined. + return f"{tag}.dev0+gunknown" + + version = get_version(tag) + + def get_commit(tag: str, version: str) -> str: + # Extract commit from version, falling back to tag if not available. + version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" + match = re.fullmatch(version_re, version) + if match is None: + raise ValueError(f"Unexpected version: {version}") + (commit,) = match.groups() + return tag if commit == "unknown" else commit + + commit = get_commit(tag, version)