Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import array | |
| import asyncio | |
| import concurrent.futures | |
| import contextvars | |
| import math | |
| import os | |
| import socket | |
| import sys | |
| import threading | |
| import weakref | |
| from asyncio import ( | |
| AbstractEventLoop, | |
| CancelledError, | |
| all_tasks, | |
| create_task, | |
| current_task, | |
| get_running_loop, | |
| sleep, | |
| ) | |
| from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] | |
| from collections import OrderedDict, deque | |
| from collections.abc import ( | |
| AsyncGenerator, | |
| AsyncIterator, | |
| Awaitable, | |
| Callable, | |
| Collection, | |
| Coroutine, | |
| Iterable, | |
| Sequence, | |
| ) | |
| from concurrent.futures import Future | |
| from contextlib import AbstractContextManager, suppress | |
| from contextvars import Context, copy_context | |
| from dataclasses import dataclass | |
| from functools import partial, wraps | |
| from inspect import ( | |
| CORO_RUNNING, | |
| CORO_SUSPENDED, | |
| getcoroutinestate, | |
| iscoroutine, | |
| ) | |
| from io import IOBase | |
| from os import PathLike | |
| from queue import Queue | |
| from signal import Signals | |
| from socket import AddressFamily, SocketKind | |
| from threading import Thread | |
| from types import CodeType, TracebackType | |
| from typing import ( | |
| IO, | |
| TYPE_CHECKING, | |
| Any, | |
| Optional, | |
| TypeVar, | |
| cast, | |
| ) | |
| from weakref import WeakKeyDictionary | |
| import sniffio | |
| from .. import ( | |
| CapacityLimiterStatistics, | |
| EventStatistics, | |
| LockStatistics, | |
| TaskInfo, | |
| abc, | |
| ) | |
| from .._core._eventloop import claim_worker_thread, threadlocals | |
| from .._core._exceptions import ( | |
| BrokenResourceError, | |
| BusyResourceError, | |
| ClosedResourceError, | |
| EndOfStream, | |
| WouldBlock, | |
| iterate_exceptions, | |
| ) | |
| from .._core._sockets import convert_ipv6_sockaddr | |
| from .._core._streams import create_memory_object_stream | |
| from .._core._synchronization import ( | |
| CapacityLimiter as BaseCapacityLimiter, | |
| ) | |
| from .._core._synchronization import Event as BaseEvent | |
| from .._core._synchronization import Lock as BaseLock | |
| from .._core._synchronization import ( | |
| ResourceGuard, | |
| SemaphoreStatistics, | |
| ) | |
| from .._core._synchronization import Semaphore as BaseSemaphore | |
| from .._core._tasks import CancelScope as BaseCancelScope | |
| from ..abc import ( | |
| AsyncBackend, | |
| IPSockAddrType, | |
| SocketListener, | |
| UDPPacketType, | |
| UNIXDatagramPacketType, | |
| ) | |
| from ..abc._eventloop import StrOrBytesPath | |
| from ..lowlevel import RunVar | |
| from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream | |
| if TYPE_CHECKING: | |
| from _typeshed import FileDescriptorLike | |
| else: | |
| FileDescriptorLike = object | |
| if sys.version_info >= (3, 10): | |
| from typing import ParamSpec | |
| else: | |
| from typing_extensions import ParamSpec | |
| if sys.version_info >= (3, 11): | |
| from asyncio import Runner | |
| from typing import TypeVarTuple, Unpack | |
| else: | |
| import contextvars | |
| import enum | |
| import signal | |
| from asyncio import coroutines, events, exceptions, tasks | |
| from exceptiongroup import BaseExceptionGroup | |
| from typing_extensions import TypeVarTuple, Unpack | |
class _State(enum.Enum):
    """Lifecycle states of the backported ``Runner`` (Python < 3.11 only)."""

    # Set in Runner.__init__(); no event loop has been created yet
    CREATED = "created"
    # Set at the end of Runner._lazy_init(), once the event loop exists
    INITIALIZED = "initialized"
    # Set at the end of Runner.close(); the runner cannot be used again
    CLOSED = "closed"
class Runner:
    """
    Backport of :class:`asyncio.Runner` for Python versions older than 3.11.

    Owns a single event loop which is created lazily on first use and torn down by
    :meth:`close`. Usable as a (synchronous) context manager.
    """

    # Copied from CPython 3.11
    def __init__(
        self,
        *,
        debug: bool | None = None,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ):
        """
        :param debug: if not ``None``, passed to ``loop.set_debug()`` when the loop is
            created
        :param loop_factory: callable used for creating the event loop; if ``None``,
            ``events.new_event_loop()`` is used and the loop is also installed as the
            current event loop
        """
        self._state = _State.CREATED
        self._debug = debug
        self._loop_factory = loop_factory
        self._loop: AbstractEventLoop | None = None
        self._context = None
        # Number of SIGINTs received during the current run() call
        self._interrupt_count = 0
        # True once this runner has called events.set_event_loop() with its own loop
        self._set_event_loop = False

    def __enter__(self) -> Runner:
        """Create the event loop (if necessary) and return the runner itself."""
        self._lazy_init()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> None:
        """Close the runner; exceptions from the ``with`` block are not suppressed."""
        self.close()

    def close(self) -> None:
        """Shutdown and close event loop."""
        # Nothing to do if the loop was never created or has already been closed
        if self._state is not _State.INITIALIZED:
            return
        try:
            loop = self._loop
            _cancel_all_tasks(loop)
            loop.run_until_complete(loop.shutdown_asyncgens())
            # Fall back to the local implementation on loops that lack the method
            if hasattr(loop, "shutdown_default_executor"):
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                loop.run_until_complete(_shutdown_default_executor(loop))
        finally:
            # Only unset the current event loop if this runner was the one to set it
            if self._set_event_loop:
                events.set_event_loop(None)
            loop.close()
            self._loop = None
            self._state = _State.CLOSED

    def get_loop(self) -> AbstractEventLoop:
        """Return embedded event loop."""
        self._lazy_init()
        return self._loop

    def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
        """
        Run a coroutine inside the embedded event loop.

        :param coro: the coroutine object to run
        :param context: context to create the task in; defaults to the context
            copied when the loop was initialized
        :return: the return value of the coroutine
        :raises ValueError: if ``coro`` is not a coroutine object
        :raises RuntimeError: if called while an event loop is already running in this
            thread, or if the runner has been closed
        :raises KeyboardInterrupt: if the task was cancelled because of SIGINT and
            ``task.uncancel()`` is available and returns 0
        """
        if not coroutines.iscoroutine(coro):
            raise ValueError(f"a coroutine was expected, got {coro!r}")

        if events._get_running_loop() is not None:
            # fail fast with short traceback
            raise RuntimeError(
                "Runner.run() cannot be called from a running event loop"
            )

        self._lazy_init()

        if context is None:
            context = self._context
        task = context.run(self._loop.create_task, coro)

        # Only take over SIGINT handling in the main thread, and only if nobody has
        # replaced Python's default handler
        if (
            threading.current_thread() is threading.main_thread()
            and signal.getsignal(signal.SIGINT) is signal.default_int_handler
        ):
            sigint_handler = partial(self._on_sigint, main_task=task)
            try:
                signal.signal(signal.SIGINT, sigint_handler)
            except ValueError:
                # `signal.signal` may throw if `threading.main_thread` does
                # not support signals (e.g. embedded interpreter with signals
                # not registered - see gh-91880)
                sigint_handler = None
        else:
            sigint_handler = None

        self._interrupt_count = 0
        try:
            return self._loop.run_until_complete(task)
        except exceptions.CancelledError:
            if self._interrupt_count > 0:
                # Task.uncancel() does not exist on all supported Python versions,
                # hence the getattr()
                uncancel = getattr(task, "uncancel", None)
                if uncancel is not None and uncancel() == 0:
                    raise KeyboardInterrupt()
            raise  # CancelledError
        finally:
            # Restore the default handler, but only if ours is still the installed one
            if (
                sigint_handler is not None
                and signal.getsignal(signal.SIGINT) is sigint_handler
            ):
                signal.signal(signal.SIGINT, signal.default_int_handler)

    def _lazy_init(self) -> None:
        """
        Create the event loop and capture the current context on first use.

        :raises RuntimeError: if the runner has already been closed
        """
        if self._state is _State.CLOSED:
            raise RuntimeError("Runner is closed")
        if self._state is _State.INITIALIZED:
            return
        if self._loop_factory is None:
            self._loop = events.new_event_loop()
            if not self._set_event_loop:
                # Call set_event_loop only once to avoid calling
                # attach_loop multiple times on child watchers
                events.set_event_loop(self._loop)
                self._set_event_loop = True
        else:
            self._loop = self._loop_factory()
        if self._debug is not None:
            self._loop.set_debug(self._debug)
        self._context = contextvars.copy_context()
        self._state = _State.INITIALIZED

    def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
        """
        SIGINT handler installed by :meth:`run`.

        The first interrupt cancels the main task (if it is still running); any other
        invocation raises :exc:`KeyboardInterrupt` directly.
        """
        self._interrupt_count += 1
        if self._interrupt_count == 1 and not main_task.done():
            main_task.cancel()
            # wakeup loop if it is blocked by select() with long timeout
            self._loop.call_soon_threadsafe(lambda: None)
            return
        raise KeyboardInterrupt()
def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
    """
    Cancel every task of ``loop`` and run the loop until they have all finished.

    Tasks that finish with an exception other than cancellation are reported through
    the loop's exception handler.

    :param loop: a loop that is not currently running
    """
    pending = tasks.all_tasks(loop)
    if not pending:
        return

    for task in pending:
        task.cancel()

    # return_exceptions=True keeps gather() from raising on the first failure
    loop.run_until_complete(tasks.gather(*pending, return_exceptions=True))

    for task in pending:
        if task.cancelled():
            continue

        error = task.exception()
        if error is not None:
            loop.call_exception_handler(
                {
                    "message": "unhandled exception during asyncio.run() shutdown",
                    "exception": error,
                    "task": task,
                }
            )
async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
    """
    Shut down the default executor of ``loop`` and wait for it to finish.

    The blocking ``shutdown(wait=True)`` call is made in a helper thread so the event
    loop keeps running in the meantime.
    """

    def _shutdown_in_thread(done: asyncio.futures.Future) -> None:
        # Runs in the helper thread; the outcome is handed back to the loop thread
        try:
            loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
            loop.call_soon_threadsafe(done.set_result, None)
        except Exception as error:
            loop.call_soon_threadsafe(done.set_exception, error)

    # The flag is set even when there is no default executor to shut down
    loop._executor_shutdown_called = True
    if loop._default_executor is None:
        return

    done = loop.create_future()
    worker = threading.Thread(target=_shutdown_in_thread, args=(done,))
    worker.start()
    try:
        await done
    finally:
        worker.join()
# Return type of user-supplied callables and coroutines
T_Retval = TypeVar("T_Retval")
# Type of the value passed to TaskStatus.started()
T_contra = TypeVar("T_contra", contravariant=True)
# Types of the positional arguments forwarded to user-supplied callables
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")

# Root task of the current event loop run, as cached by find_root_task()
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
def find_root_task() -> asyncio.Task:
    """
    Return the root task of the current event loop run.

    The following are tried in order:

    #. the task cached in ``_root_task``, unless it has already finished
    #. an unfinished task carrying the ``run_until_complete()`` done callback (or a
       callback defined in ``uvloop.loop``); this task is then cached
    #. the host task of the outermost cancel scope enclosing the current task
    #. the current task itself
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback, _context in candidate._callbacks:
            if (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            ):
                _root_task.set(candidate)
                return candidate

    # Fall back to walking up the cancel scope hierarchy of the current task
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if state:
        outermost = state.cancel_scope
        while outermost and outermost._parent_scope is not None:
            outermost = outermost._parent_scope

        if outermost is not None:
            return cast(asyncio.Task, outermost._host_task)

    return current
| def get_callable_name(func: Callable) -> str: | |
| module = getattr(func, "__module__", None) | |
| qualname = getattr(func, "__qualname__", None) | |
| return ".".join([x for x in (module, qualname) if x]) | |
| # | |
| # Event loop | |
| # | |
# Storage keyed by event loop. The keys are weakly referenced, so an entry does not
# keep its event loop alive.
# NOTE(review): presumably the backing store for RunVar values - the users of this
# mapping are not visible in this part of the file; confirm.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
def _task_started(task: asyncio.Task) -> bool:
    """
    Tell whether the coroutine of a task is currently running or suspended.

    :return: ``True`` if the task has been started and has not finished
    :raises Exception: if the state of the task's coroutine cannot be inspected
    """
    # Finished tasks are never added to the task list, so the coroutine should always
    # be available here
    coro = task.get_coro()
    assert coro is not None
    try:
        state = getcoroutinestate(coro)
    except AttributeError:
        # The task coro is an async_generator_asend object
        # (https://bugs.python.org/issue37771)
        raise Exception(f"Cannot determine if task {task} has started or not") from None

    return state in (CORO_RUNNING, CORO_SUSPENDED)
| # | |
| # Timeouts and cancellation | |
| # | |
def is_anyio_cancellation(exc: CancelledError) -> bool:
    """
    Tell whether a cancellation exception was caused by an AnyIO cancel scope.

    Third party frameworks sometimes catch a :exc:`CancelledError` and raise a new
    one in its place. To cope with that, the chain of ``__context__`` exceptions is
    followed for as long as it consists of :exc:`CancelledError` instances.

    :return: ``True`` if ``exc`` or one of the chained cancellation exceptions
        carries a cancel scope message as its first argument
    """
    candidate = exc
    while True:
        message = candidate.args[0] if candidate.args else None
        if isinstance(message, str) and message.startswith(
            "Cancelled by cancel scope "
        ):
            return True

        previous = candidate.__context__
        if not isinstance(previous, CancelledError):
            return False

        candidate = previous
class CancelScope(BaseCancelScope):
    """
    Cancel scope implementation for the asyncio backend.

    A scope tracks the tasks directly contained in it as well as its child scopes.
    Once cancelled, it keeps calling ``cancel()`` on the eligible tasks on every event
    loop iteration until there are no tasks left in it.

    ``deadline``, ``shield``, ``cancel_called`` and ``cancelled_caught`` are
    properties, as are the private ``_effectively_cancelled`` and
    ``_parent_cancellation_is_visible_to_us`` helpers; all of them are read as plain
    attributes throughout this module.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Create the instance directly, skipping any __new__() defined on the base
        # class
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        """
        :param deadline: event loop time at which the scope cancels itself
            (``math.inf`` for no deadline)
        :param shield: ``True`` to shield the scope from cancellation of its parent
            scopes
        """
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        self._timeout_handle: asyncio.TimerHandle | None = None
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of cancel() calls made on the host task that need to be undone with
        # uncancel(); only tracked on Python 3.11+
        if sys.version_info >= (3, 11):
            self._pending_uncancellations: int | None = 0
        else:
            self._pending_uncancellations = None

    def __enter__(self) -> CancelScope:
        """
        Make this the current cancel scope of the current task.

        :raises RuntimeError: if the scope is already active
        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                # If using an eager task factory, the parent scope may not even contain
                # the host task
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.discard(host_task)

        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """
        Restore the parent scope as the current cancel scope of the host task.

        :return: ``True`` if the exception consisted solely of cancellation errors
            raised by an AnyIO cancel scope and should be swallowed
        :raises RuntimeError: if the scope is not active, is exited from another task
            than it was entered in, or is not the current cancel scope of the task
        """
        del exc_tb

        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        try:
            self._active = False
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            # Hand the host task back to the parent scope
            self._tasks.remove(self._host_task)
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.remove(self)
                self._parent_scope._tasks.add(self._host_task)

            host_task_state.cancel_scope = self._parent_scope

            # Restart the cancellation effort in the closest visible, cancelled parent
            # scope if necessary
            self._restart_cancellation_in_parent()

            # We only swallow the exception iff it was an AnyIO CancelledError, either
            # directly as exc_val or inside an exception group and there are no cancelled
            # parent cancel scopes visible to us here
            if self._cancel_called and not self._parent_cancellation_is_visible_to_us:
                # For each level-cancel() call made on the host task, call uncancel()
                while self._pending_uncancellations:
                    self._host_task.uncancel()
                    self._pending_uncancellations -= 1

                # Update cancelled_caught and check for exceptions we must not swallow
                cannot_swallow_exc_val = False
                if exc_val is not None:
                    for exc in iterate_exceptions(exc_val):
                        if isinstance(exc, CancelledError) and is_anyio_cancellation(
                            exc
                        ):
                            self._cancelled_caught = True
                        else:
                            cannot_swallow_exc_val = True

                return self._cancelled_caught and not cannot_swallow_exc_val
            else:
                # Let the parent scope take care of the pending uncancellations
                if self._pending_uncancellations:
                    assert self._parent_scope is not None
                    assert self._parent_scope._pending_uncancellations is not None
                    self._parent_scope._pending_uncancellations += (
                        self._pending_uncancellations
                    )
                    self._pending_uncancellations = 0

                return False
        finally:
            self._host_task = None
            del exc_val

    @property
    def _effectively_cancelled(self) -> bool:
        """
        ``True`` if this scope, or an enclosing scope not hidden behind a shielded
        scope, has been cancelled.
        """
        cancel_scope: CancelScope | None = self
        while cancel_scope is not None:
            if cancel_scope._cancel_called:
                return True

            # Cancellation of the scopes beyond a shielded scope is not visible
            if cancel_scope.shield:
                return False

            cancel_scope = cancel_scope._parent_scope

        return False

    @property
    def _parent_cancellation_is_visible_to_us(self) -> bool:
        """
        ``True`` if this scope is unshielded and has an effectively cancelled parent.
        """
        return (
            self._parent_scope is not None
            and not self.shield
            and self._parent_scope._effectively_cancelled
        )

    def _timeout(self) -> None:
        """Cancel the scope if its deadline has passed; otherwise (re)arm the timer."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            should_retry = True
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            if task is not current and (task is self._host_task or _task_started(task)):
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    if (
                        task is origin._host_task
                        and origin._pending_uncancellations is not None
                    ):
                        origin._pending_uncancellations += 1

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # A cancel handle means that a delivery is already scheduled
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def cancel(self) -> None:
        """
        Cancel the scope. Calling this on an already cancelled scope does nothing.

        If the scope has not been entered yet, cancellation is only delivered once it
        has been entered.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        """The event loop time at which this scope cancels itself."""
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the timer for the new deadline if the scope is in use
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        """``True`` if :meth:`cancel` has been called."""
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        """``True`` if exiting the scope saw a cancellation error raised by AnyIO."""
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        """``True`` if the scope is shielded from cancellation of its parent scopes."""
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Removing the shield exposes the scope to a cancelled parent scope again
            if not value:
                self._restart_cancellation_in_parent()
| # | |
| # Task states | |
| # | |
class TaskState:
    """
    Holds the extra per-task bookkeeping data that cannot be stored on the Task
    instance itself, as nothing is guaranteed about the Task implementation.

    :ivar parent_id: ``id()`` of the parent task, or ``None``
    :ivar cancel_scope: the cancel scope the task is currently in, or ``None``
    """

    __slots__ = ("parent_id", "cancel_scope", "__weakref__")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        self.cancel_scope = cancel_scope
        self.parent_id = parent_id
# Maps each task known to this backend to its TaskState. The keys are weakly
# referenced, so an entry does not keep its task alive.
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
| # | |
| # Task groups | |
| # | |
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object given to tasks launched with ``TaskGroup.start()``.

    :param future: the future resolved with the value passed to :meth:`started`
    :param parent_id: ``id()`` of the task that becomes the parent of the started
        task once it has called :meth:`started`
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._parent_id = parent_id
        self._future = future

    def started(self, value: T_contra | None = None) -> None:
        """
        Signal that the task has started, and reparent it.

        :param value: the value to return from ``TaskGroup.start()``
        :raises RuntimeError: if called twice on the same task status
        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            # A cancelled future is not a sign of a second call to started()
            if not self._future.cancelled():
                raise RuntimeError(
                    "called 'started' twice on the same task status"
                ) from None

        running_task = cast(asyncio.Task, current_task())
        state = _task_states[running_task]
        state.parent_id = self._parent_id
# Code object of asyncio.eager_task_factory (which only exists on Python 3.12+).
# TaskGroup._spawn() compares it against the code object of the event loop's task
# factory to recognize eager task factories.
if sys.version_info >= (3, 12):
    _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
else:
    _eager_task_factory_code = None
class TaskGroup(abc.TaskGroup):
    """
    Task group implementation for the asyncio backend.

    All child tasks run inside :attr:`cancel_scope`. Exiting the group waits for the
    child tasks to finish and raises the collected errors as an exception group.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        # True between __aenter__() and the point where all child tasks have finished
        self._active = False
        # Errors raised by the host task and the child tasks (other than cancellation)
        self._exceptions: list[BaseException] = []
        # Child tasks that have not finished yet
        self._tasks: set[asyncio.Task] = set()
        # Resolved by a child task's done callback when no child tasks are left
        self._on_completed_fut: asyncio.Future[None] | None = None

    async def __aenter__(self) -> TaskGroup:
        """Enter the task group's cancel scope and allow tasks to be started."""
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all the child tasks to finish and exit the cancel scope.

        If the host task raised an exception, the cancel scope is cancelled first.

        :return: ``True`` if the cancel scope swallowed the exception
        :raises BaseExceptionGroup: if the host task or any child task raised an
            exception other than a cancellation error
        """
        try:
            if exc_val is not None:
                self.cancel_scope.cancel()
                if not isinstance(exc_val, CancelledError):
                    self._exceptions.append(exc_val)

            loop = get_running_loop()
            try:
                if self._tasks:
                    with CancelScope() as wait_scope:
                        while self._tasks:
                            self._on_completed_fut = loop.create_future()

                            try:
                                await self._on_completed_fut
                            except CancelledError as exc:
                                # Shield the scope against further cancellation attempts,
                                # as they're not productive (#695)
                                wait_scope.shield = True
                                self.cancel_scope.cancel()

                                # Set exc_val from the cancellation exception if it was
                                # previously unset. However, we should not replace a native
                                # cancellation exception with one raised by a cancel scope.
                                if exc_val is None or (
                                    isinstance(exc_val, CancelledError)
                                    and not is_anyio_cancellation(exc)
                                ):
                                    exc_val = exc

                            self._on_completed_fut = None
                else:
                    # If there are no child tasks to wait on, run at least one checkpoint
                    # anyway
                    await AsyncIOBackend.cancel_shielded_checkpoint()

                self._active = False
                if self._exceptions:
                    # The exception that got us here should already have been
                    # added to self._exceptions so it's ok to break exception
                    # chaining and avoid adding a "During handling of above..."
                    # for each nesting level.
                    raise BaseExceptionGroup(
                        "unhandled errors in a TaskGroup", self._exceptions
                    ) from None
                elif exc_val:
                    raise exc_val
            except BaseException as exc:
                # Give the cancel scope a chance to swallow the exception
                if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
                    return True

                raise

            return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # NOTE(review): presumably deleted to break reference cycles through the
            # exceptions' tracebacks - confirm
            del exc_val, exc_tb, self._exceptions

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task in this task group.

        :param func: callable that returns a coroutine object when called
        :param args: positional arguments to call ``func`` with
        :param name: name of the task; if ``None``, derived from ``func``
        :param task_status_future: future for the start value of the task; if given,
            a ``task_status`` keyword argument is also passed to ``func``
        :return: the new task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` did not return a coroutine object
        """

        def task_done(_task: asyncio.Task) -> None:
            # Done callback: removes the task from the bookkeeping and records its
            # outcome
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(task)
            del _task_states[_task]

            # Wake up __aexit__() if this was the last child task
            if self._on_completed_fut is not None and not self._tasks:
                try:
                    self._on_completed_fut.set_result(None)
                except asyncio.InvalidStateError:
                    pass

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Use the innermost cancellation error in the context chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # The future can only be in the cancelled state if the host task was
                # cancelled, so return immediately instead of adding one more
                # CancelledError to the exceptions list
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._effectively_cancelled:
                        self.cancel_scope.cancel()
                else:
                    # The task failed before calling task_status.started(), so the
                    # error is passed to the caller of start() instead
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # Until the task calls task_status.started(), the task calling start() is
            # recorded as its parent
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        loop = asyncio.get_running_loop()
        # If the loop's task factory shares its code object with
        # asyncio.eager_task_factory and has a closure, call the first closure cell's
        # content to create the task instead of create_task()
        if (
            (factory := loop.get_task_factory())
            and getattr(factory, "__code__", None) is _eager_task_factory_code
            and (closure := getattr(factory, "__closure__", None))
        ):
            custom_task_constructor = closure[0].cell_contents
            task = custom_task_constructor(coro, loop=loop, name=name)
        else:
            task = create_task(coro, name=name)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        task.add_done_callback(task_done)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        """
        Start a new child task without waiting for it.

        :param func: callable that returns a coroutine object when called
        :param args: positional arguments to call ``func`` with
        :param name: name of the task
        """
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Start a new child task and wait until it signals that it has started.

        :param func: callable that returns a coroutine object when called with
            ``args`` and a ``task_status`` keyword argument
        :param args: positional arguments to call ``func`` with
        :param name: name of the task
        :return: the value the task passed to ``task_status.started()``
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
| # | |
| # Threads | |
| # | |
# NOTE(review): presumably a (return value, exception) pair passed through a result
# queue - no use of this alias is visible in this part of the file; confirm.
_Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]]
class WorkerThread(Thread):
    """
    Thread that runs synchronous callables handed to it through a queue.

    The outcome of each call is reported back to a future in the event loop of the
    root task.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        """
        :param root_task: task whose event loop the results are reported to
        :param workers: the collection this worker removes itself from when stopped
        :param idle_workers: the collection this worker adds itself to after
            finishing a job
        """
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Jobs to run; None is the shutdown command
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Mark the worker as idle and deliver the outcome of a job to its future.

        Scheduled by :meth:`run` with ``loop.call_soon_threadsafe()``.
        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                # Hand over a RuntimeError instead of the StopIteration
                if isinstance(exc, StopIteration):
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """Run jobs from the queue until the shutdown command is received."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future, cancel_scope = item
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    # Expose the cancel scope to the job for the duration of the call
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    # The result cannot be reported to a closed event loop
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                    del result, exception

                self.queue.task_done()
                # Drop the references to the job while waiting for the next one
                del item, context, func, args, future, cancel_scope

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Tell the thread to exit and remove it from the worker collections.

        :param f: not used
            (NOTE(review): presumably there so that this method can be used as a task
            done callback - confirm)
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # The worker was not idle
            pass
# Worker threads that are currently idle
# NOTE(review): the code using these two variables is not visible in this part of the
# file; the descriptions are based on the names and types alone - confirm.
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# All worker threads, idle or busy
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop it was created in."""

    def __new__(cls) -> BlockingPortal:
        # Create the instance directly, skipping any __new__() defined on the base
        # class
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Start a task in the portal's task group from another thread.

        :param func: the callable for the task to call
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :param name: name of the task
        :param future: the future passed along to ``_call_func()``
        """
        start_named_task = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_named_task, call_args, self._loop)
| # | |
| # Subprocesses | |
| # | |
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """
    Byte receive stream that reads from an :class:`asyncio.StreamReader`.

    Declared as a dataclass so that the wrapped stream can be passed to the
    constructor. ``eq=False`` keeps comparison and hashing identity based.
    """

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to ``max_bytes`` bytes from the stream.

        :raises EndOfStream: if the read returned no data
        """
        data = await self._stream.read(max_bytes)
        if data:
            return data
        else:
            raise EndOfStream

    async def aclose(self) -> None:
        """Set :exc:`ClosedResourceError` on the reader and run a checkpoint."""
        self._stream.set_exception(ClosedResourceError())
        await AsyncIOBackend.checkpoint()
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """
    Byte send stream that writes to an :class:`asyncio.StreamWriter`.

    Declared as a dataclass so that the wrapped stream can be passed to the
    constructor. ``eq=False`` keeps comparison and hashing identity based.
    """

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write ``item`` to the stream and wait for the write buffer to drain."""
        self._stream.write(item)
        await self._stream.drain()

    async def aclose(self) -> None:
        """Close the writer and run a checkpoint."""
        self._stream.close()
        await AsyncIOBackend.checkpoint()
@dataclass(eq=False)
class Process(abc.Process):
    """
    Subprocess handle wrapping an :class:`asyncio.subprocess.Process`.

    Declared as a dataclass so that the wrapped process and its streams can be passed
    to the constructor. ``eq=False`` keeps the instances hashable by identity, which
    is required for storing them in sets.
    """

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """
        Close the standard streams and wait for the process to exit.

        If the wait is interrupted by any exception (including cancellation), the
        process is killed and waited on before the exception is propagated.
        """
        with CancelScope(shield=True) as scope:
            if self._stdin:
                await self._stdin.aclose()
            if self._stdout:
                await self._stdout.aclose()
            if self._stderr:
                await self._stderr.aclose()

            # The streams were closed under the shield; the wait can be cancelled
            scope.shield = False
            try:
                await self.wait()
            except BaseException:
                # Restore the shield so that waiting on the killed process cannot be
                # interrupted
                scope.shield = True
                self.kill()
                await self.wait()
                raise

    async def wait(self) -> int:
        """Wait for the process to exit and return its return code."""
        return await self._process.wait()

    def terminate(self) -> None:
        """Call ``terminate()`` on the wrapped process."""
        self._process.terminate()

    def kill(self) -> None:
        """Call ``kill()`` on the wrapped process."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """
        Send a signal to the process.

        :param signal: the signal number
        """
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        """The process ID of the process."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The ``returncode`` of the wrapped process (may be ``None``)."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        """The stream for the standard input of the process, if any."""
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        """The stream for the standard output of the process, if any."""
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        """The stream for the standard error output of the process, if any."""
        return self._stderr
| def _forcibly_shutdown_process_pool_on_exit( | |
| workers: set[Process], _task: object | |
| ) -> None: | |
| """ | |
| Forcibly shuts down worker processes belonging to this event loop.""" | |
| child_watcher: asyncio.AbstractChildWatcher | None = None | |
| if sys.version_info < (3, 12): | |
| try: | |
| child_watcher = asyncio.get_event_loop_policy().get_child_watcher() | |
| except NotImplementedError: | |
| pass | |
| # Close as much as possible (w/o async/await) to avoid warnings | |
| for process in workers: | |
| if process.returncode is None: | |
| continue | |
| process._stdin._stream._transport.close() # type: ignore[union-attr] | |
| process._stdout._stream._transport.close() # type: ignore[union-attr] | |
| process._stderr._stream._transport.close() # type: ignore[union-attr] | |
| process.kill() | |
| if child_watcher: | |
| child_watcher.remove_child_handler(process.pid) | |
| async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None: | |
| """ | |
| Shuts down worker processes belonging to this event loop. | |
| NOTE: this only works when the event loop was started using asyncio.run() or | |
| anyio.run(). | |
| """ | |
| process: abc.Process | |
| try: | |
| await sleep(math.inf) | |
| except asyncio.CancelledError: | |
| for process in workers: | |
| if process.returncode is None: | |
| process.kill() | |
| for process in workers: | |
| await process.aclose() | |
| # | |
| # Sockets and networking | |
| # | |
class StreamProtocol(asyncio.Protocol):
    """
    Protocol that queues incoming chunks and tracks read/write readiness in events.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None
    is_at_eof: bool = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        # Writing starts out unblocked
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable
        stream_transport = cast(asyncio.Transport, transport)
        stream_transport.set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            # Record the failure as a BrokenResourceError caused by ``exc``
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Set both events so that nothing keeps waiting on a dead connection
        for event in (self.read_event, self.write_event):
            event.set()

    def data_received(self, data: bytes) -> None:
        # ProactorEventloop sometimes sends bytearray instead of bytes
        chunk = bytes(data)
        self.read_queue.append(chunk)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.is_at_eof = True
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        # Swap in a fresh, unset event rather than clearing the current one
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Protocol that queues incoming datagrams and tracks readiness in events.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        # Writing starts out unblocked
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        # Set both events so that nothing keeps waiting on a dead endpoint
        for event in (self.read_event, self.write_event):
            event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        # Store the packet together with its converted source address
        self.read_queue.append((data, convert_ipv6_sockaddr(addr)))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
class SocketStream(abc.SocketStream):
    """Socket stream driven by an asyncio transport and a :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The transport's ``"socket"`` extra info entry."""
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of received data, at most ``max_bytes`` long.

        :param max_bytes: maximum length of the returned chunk
        :raises ClosedResourceError: if nothing is buffered and this stream was closed
        :raises EndOfStream: if nothing is buffered and no error was recorded
        :raises Exception: the protocol's recorded exception, if nothing is buffered

        """
        with self._receive_guard:
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
                and not self._protocol.is_at_eof
            ):
                # Reading is only enabled for the duration of the wait
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait until writing is not paused.

        :param item: the bytes to send
        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: if the write fails with a :exc:`RuntimeError`
            while the transport is closing
        :raises Exception: the protocol's recorded exception, if there is one

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Half-close the transport for writing; :exc:`OSError` is ignored."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            # Request a regular close, yield to the loop once, then abort
            self._transport.close()
            await sleep(0)
            self._transport.abort()
| class _RawSocketMixin: | |
| _receive_future: asyncio.Future | None = None | |
| _send_future: asyncio.Future | None = None | |
| _closing = False | |
| def __init__(self, raw_socket: socket.socket): | |
| self.__raw_socket = raw_socket | |
| self._receive_guard = ResourceGuard("reading from") | |
| self._send_guard = ResourceGuard("writing to") | |
| def _raw_socket(self) -> socket.socket: | |
| return self.__raw_socket | |
| def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: | |
| def callback(f: object) -> None: | |
| del self._receive_future | |
| loop.remove_reader(self.__raw_socket) | |
| f = self._receive_future = asyncio.Future() | |
| loop.add_reader(self.__raw_socket, f.set_result, None) | |
| f.add_done_callback(callback) | |
| return f | |
| def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: | |
| def callback(f: object) -> None: | |
| del self._send_future | |
| loop.remove_writer(self.__raw_socket) | |
| f = self._send_future = asyncio.Future() | |
| loop.add_writer(self.__raw_socket, f.set_result, None) | |
| f.add_done_callback(callback) | |
| return f | |
| async def aclose(self) -> None: | |
| if not self._closing: | |
| self._closing = True | |
| if self.__raw_socket.fileno() != -1: | |
| self.__raw_socket.close() | |
| if self._receive_future: | |
| self._receive_future.set_result(None) | |
| if self._send_future: | |
| self._send_future.set_result(None) | |
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """UNIX socket stream operating directly on a raw socket."""

    async def send_eof(self) -> None:
        """Shut down the writing side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes, waiting for readability when needed.

        :raises EndOfStream: if the socket returned no data
        :raises ClosedResourceError: if an OS error occurs after this stream was closed
        :raises BrokenResourceError: on any other OS error

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    received = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    if received:
                        return received

                    raise EndOfStream

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item``, waiting for writability when needed.

        :raises ClosedResourceError: if an OS error occurs after this stream was closed
        :raises BrokenResourceError: on any other OS error

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            # Keep sending the unsent tail until nothing is left
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message along with file descriptors from its ancillary data.

        :param msglen: maximum message length passed to ``recvmsg()``
        :param maxfds: number of file descriptors to reserve ancillary space for
        :return: a tuple of (message, list of received file descriptors)
        :raises ValueError: if ``msglen`` or ``maxfds`` is out of range or not an int
        :raises EndOfStream: if neither a message nor ancillary data was received
        :raises RuntimeError: if the ancillary data is not ``SCM_RIGHTS``

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        event_loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    if message or ancdata:
                        break

                    raise EndOfStream

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Only whole items are decoded; a trailing partial item is dropped
            usable_length = len(cmsg_data) - len(cmsg_data) % fds.itemsize
            fds.frombytes(cmsg_data[:usable_length])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message along with file descriptors as ``SCM_RIGHTS`` ancillary data.

        :param message: the (non-empty) message to send
        :param fds: integers, or ``IOBase`` objects whose ``fileno()`` is used
        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: if an OS error occurs after this stream was closed
        :raises BrokenResourceError: on any other OS error

        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        event_loop = get_running_loop()
        # Entries that are neither ints nor IOBase instances are left out
        filenos: list[int] = []
        for candidate in fds:
            if isinstance(candidate, int):
                filenos.append(candidate)
            elif isinstance(candidate, IOBase):
                filenos.append(candidate.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    # The ignore can be removed after mypy picks up
                    # https://github.com/python/typeshed/pull/5545
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    break
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections on a raw listening socket."""

    # Cancel scope of the accept() call in progress, if any
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The raw listening socket passed to the constructor."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`SocketStream`.

        :raises ClosedResourceError: if the listener is closed, either before the
            call or while it is waiting

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # aclose() cancels the accept scope after marking the listener
                    # closed; report that as a closed resource
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listening socket, cancelling any accept() in progress."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Yield to the event loop once before closing the socket
            await sleep(0)

        self._raw_socket.close()
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts UNIX socket connections on a raw listening socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`UNIXSocketStream`.

        :raises ClosedResourceError: if an OS error occurs after the listener was
            closed
        :raises BrokenResourceError: on any other OS error

        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    # Nothing to accept yet: wait until the loop reports the
                    # listening socket as readable, then try again
                    f: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(self.__raw_socket, f.set_result, None)
                    f.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    await f
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener as closed and close the listening socket."""
        self._closed = True
        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        """The raw listening socket passed to the constructor."""
        return self.__raw_socket
class UDPSocket(abc.UDPSocket):
    """UDP socket driven by a datagram transport and a :class:`DatagramProtocol`."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The transport's ``"socket"`` extra info entry."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next queued ``(data, address)`` packet, waiting if none is queued.

        :raises ClosedResourceError: if nothing is queued and this socket was closed
        :raises BrokenResourceError: if nothing is queued for any other reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                return self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a packet once writing is not paused.

        :param item: the arguments to unpack into the transport's ``sendto()``
        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: if the transport is closing for another reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            elif self._transport.is_closing():
                raise BrokenResourceError
            else:
                self._transport.sendto(*item)
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """
    Connected UDP socket driven by a datagram transport and a
    :class:`DatagramProtocol`.
    """

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The transport's ``"socket"`` extra info entry."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next queued packet, waiting if none is queued.

        :raises ClosedResourceError: if nothing is queued and this socket was closed
        :raises BrokenResourceError: if nothing is queued for any other reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                packet = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

            # Queue entries are (data, address) pairs; only the data is returned
            return packet[0]

    async def send(self, item: bytes) -> None:
        """
        Send a datagram once writing is not paused.

        :param item: the payload passed to the transport's ``sendto()``
        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: if the transport is closing for another reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            elif self._transport.is_closing():
                raise BrokenResourceError
            else:
                self._transport.sendto(item)
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket operating directly on a raw socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Return the result of ``recvfrom(65536)``, waiting for readability as needed.

        :raises ClosedResourceError: if an OS error occurs after this socket was closed
        :raises BrokenResourceError: on any other OS error

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send a packet, waiting for writability as needed.

        :param item: the arguments to unpack into the socket's ``sendto()``
        :raises ClosedResourceError: if an OS error occurs after this socket was closed
        :raises BrokenResourceError: on any other OS error

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket operating directly on a raw socket."""

    async def receive(self) -> bytes:
        """
        Return the result of ``recv(65536)``, waiting for readability as needed.

        :raises ClosedResourceError: if an OS error occurs after this socket was closed
        :raises BrokenResourceError: on any other OS error

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: bytes) -> None:
        """
        Send a datagram, waiting for writability as needed.

        :param item: the payload passed to the socket's ``send()``
        :raises ClosedResourceError: if an OS error occurs after this socket was closed
        :raises BrokenResourceError: on any other OS error

        """
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
# Run-scoped mappings from an int key to an asyncio.Event, one for reads and one
# for writes (presumably keyed by file descriptor -- confirm against the users)
_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")
| # | |
| # Synchronization | |
| # | |
class Event(BaseEvent):
    """Event implemented on top of :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Allocate through ``object`` directly instead of the base class' ``__new__``
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        self._event.set()

    def is_set(self) -> bool:
        return self._event.is_set()

    async def wait(self) -> None:
        """Wait until the event is set; if it already is, just checkpoint."""
        if not self.is_set():
            await self._event.wait()
            return

        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        # Reports the number of waiters registered on the underlying asyncio event
        waiter_count = len(self._event._waiters)
        return EventStatistics(waiter_count)
class Lock(BaseLock):
    """Task-owned lock that hands ownership to waiters in arrival order."""

    def __new__(cls, *, fast_acquire: bool = False) -> Lock:
        # Allocate through ``object`` directly instead of the base class' ``__new__``
        return object.__new__(cls)

    def __init__(self, *, fast_acquire: bool = False) -> None:
        self._fast_acquire = fast_acquire
        self._owner_task: asyncio.Task | None = None
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        """
        Acquire the lock, waiting in line if it is held or others are waiting.

        :raises RuntimeError: if the current task already holds the lock

        """
        me = cast(asyncio.Task, current_task())
        uncontended = self._owner_task is None and not self._waiters
        if uncontended:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = me

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        if self._owner_task == me:
            raise RuntimeError("Attempted to acquire an already held Lock")

        # Join the line; release() completes the future when our turn comes
        turn: asyncio.Future[None] = asyncio.Future()
        entry = me, turn
        self._waiters.append(entry)
        try:
            await turn
        except CancelledError:
            self._waiters.remove(entry)
            # Ownership may already have been handed to this task; pass it on
            if self._owner_task is me:
                self.release()

            raise

        self._waiters.remove(entry)

    def acquire_nowait(self) -> None:
        """
        Acquire the lock without waiting.

        :raises RuntimeError: if the current task already holds the lock
        :raises WouldBlock: if the lock is held or other tasks are waiting for it

        """
        me = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            self._owner_task = me
            return

        if self._owner_task is me:
            raise RuntimeError("Attempted to acquire an already held Lock")

        raise WouldBlock

    def locked(self) -> bool:
        return self._owner_task is not None

    def release(self) -> None:
        """
        Release the lock, handing it to the first waiter that was not cancelled.

        :raises RuntimeError: if the current task does not hold the lock

        """
        if self._owner_task != current_task():
            raise RuntimeError("The current task is not holding this lock")

        for waiting_task, turn in self._waiters:
            if turn.cancelled():
                continue

            self._owner_task = waiting_task
            turn.set_result(None)
            return

        # Nobody eligible is waiting: the lock becomes free
        self._owner_task = None

    def statistics(self) -> LockStatistics:
        owner = self._owner_task
        owner_info = AsyncIOTaskInfo(owner) if owner else None
        return LockStatistics(self.locked(), owner_info, len(self._waiters))
class Semaphore(BaseSemaphore):
    """Semaphore that wakes up waiters in arrival order."""

    def __new__(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> Semaphore:
        # Allocate through ``object`` directly instead of the base class' ``__new__``
        return object.__new__(cls)

    def __init__(
        self,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ):
        super().__init__(initial_value, max_value=max_value)
        self._value = initial_value
        self._max_value = max_value
        self._fast_acquire = fast_acquire
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        """Decrement the semaphore, waiting in line if that is not possible yet."""
        if self._value > 0 and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        fut: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(fut)
        try:
            await fut
        except CancelledError:
            try:
                self._waiters.remove(fut)
            except ValueError:
                # release() already took this future out of the queue, so the slot
                # was granted to us; give it back
                self.release()

            raise

    def acquire_nowait(self) -> None:
        """
        Decrement the semaphore without waiting.

        :raises WouldBlock: if the current value is zero

        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """
        Wake up the first waiter that was not cancelled, or increment the value.

        :raises ValueError: if the value is already at the configured maximum

        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        for fut in self._waiters:
            if not fut.cancelled():
                fut.set_result(None)
                # Mutating the deque here is fine as the loop is left right away
                self._waiters.remove(fut)
                return

        self._value += 1

    @property
    def value(self) -> int:
        """The current value of the semaphore."""
        return self._value

    @property
    def max_value(self) -> int | None:
        """The maximum value given at construction, if any."""
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        return SemaphoreStatistics(len(self._waiters))
class CapacityLimiter(BaseCapacityLimiter):
    """Limiter that lets at most ``total_tokens`` borrowers hold a token at once."""

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Allocate through ``object`` directly instead of the base class' ``__new__``
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        # Goes through the property setter, which validates the value
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the token total, waking up waiters when the total grows.

        :raises TypeError: if ``value`` is neither an int nor infinite
        :raises ValueError: if ``value`` is less than 1

        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        # Only an increase over the previous total frees up slots for waiters
        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Notify waiting tasks that they have acquired the limiter
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently held by borrowers."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens not currently held by any borrower."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """Acquire a token for the current task without waiting."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower`` without waiting.

        :param borrower: the object to record as holding the token
        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if others are waiting or no tokens are left

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower``, waiting in line if necessary.

        :param borrower: the object to record as holding the token
        :raises RuntimeError: if ``borrower`` already holds a token

        """
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Leave the line if the wait was interrupted
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by ``borrower``.

        :param borrower: the object recorded as holding the token
        :raises RuntimeError: if ``borrower`` does not hold a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
# Run-scoped slot holding a CapacityLimiter (presumably the default limiter applied
# to worker threads -- confirm against the code that reads it)
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar(
    "_default_thread_limiter"
)
| # | |
| # Operating system signals | |
| # | |
| class _SignalReceiver: | |
| def __init__(self, signals: tuple[Signals, ...]): | |
| self._signals = signals | |
| self._loop = get_running_loop() | |
| self._signal_queue: deque[Signals] = deque() | |
| self._future: asyncio.Future = asyncio.Future() | |
| self._handled_signals: set[Signals] = set() | |
| def _deliver(self, signum: Signals) -> None: | |
| self._signal_queue.append(signum) | |
| if not self._future.done(): | |
| self._future.set_result(None) | |
| def __enter__(self) -> _SignalReceiver: | |
| for sig in set(self._signals): | |
| self._loop.add_signal_handler(sig, self._deliver, sig) | |
| self._handled_signals.add(sig) | |
| return self | |
| def __exit__( | |
| self, | |
| exc_type: type[BaseException] | None, | |
| exc_val: BaseException | None, | |
| exc_tb: TracebackType | None, | |
| ) -> None: | |
| for sig in self._handled_signals: | |
| self._loop.remove_signal_handler(sig) | |
| def __aiter__(self) -> _SignalReceiver: | |
| return self | |
| async def __anext__(self) -> Signals: | |
| await AsyncIOBackend.checkpoint() | |
| if not self._signal_queue: | |
| self._future = asyncio.Future() | |
| await self._future | |
| return self._signal_queue.popleft() | |
| # | |
| # Testing and debugging | |
| # | |
class AsyncIOTaskInfo(TaskInfo):
    """Task information for an asyncio task, holding only a weak reference to it."""

    def __init__(self, task: asyncio.Task):
        task_state = _task_states.get(task)
        # Tasks without a recorded state have no known parent
        parent_id = None if task_state is None else task_state.parent_id
        coro = task.get_coro()
        assert coro is not None, "created TaskInfo from a completed Task"
        super().__init__(id(task), parent_id, task.get_name(), coro)
        self._task = weakref.ref(task)

    def has_pending_cancellation(self) -> bool:
        """
        Tell whether the task has a cancellation pending.

        Checks, in order: the task's own pending-cancel flag, whether the future it
        is waiting on was cancelled, and finally the effective cancellation status
        of the task's recorded cancel scope.

        """
        task = self._task()
        if not task:
            # If the task isn't around anymore, it won't have a pending cancellation
            return False

        if task._must_cancel:  # type: ignore[attr-defined]
            return True

        waiter = task._fut_waiter  # type: ignore[attr-defined]
        if isinstance(waiter, asyncio.Future) and waiter.cancelled():
            return True

        task_state = _task_states.get(task)
        if task_state:
            cancel_scope = task_state.cancel_scope
            if cancel_scope:
                return cancel_scope._effectively_cancelled

        return False
class TestRunner(abc.TestRunner):
    """
    Run pytest test functions and async fixtures on one event loop.

    Everything submitted through :meth:`run_test`, :meth:`run_fixture` and
    :meth:`run_asyncgen_fixture` is awaited, one item at a time, by a single
    long-lived runner task that is started on first use. Exceptions reported to the
    loop's exception handler are collected and re-raised after each call.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        """
        :param debug: passed through to the underlying ``Runner``
        :param use_uvloop: if ``True`` and no ``loop_factory`` was given, use
            ``uvloop.new_event_loop`` as the loop factory
        :param loop_factory: callable producing the event loop; takes precedence
            over ``use_uvloop``

        """
        if loop_factory is None and use_uvloop:
            # Imported lazily so that uvloop is only needed when asked for
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None
        self._runner = Runner(debug=debug, loop_factory=loop_factory)

    def __enter__(self) -> TestRunner:
        """Enter the underlying runner and install the exception collector."""
        self._runner.__enter__()
        loop = self.get_loop()
        loop.set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the underlying runner, forwarding the exception details."""
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the event loop owned by the underlying runner."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """
        Collect ``Exception`` instances reported to the loop's exception handler.

        Contexts without an exception, or carrying a non-``Exception``
        ``BaseException``, are passed to the loop's default handler instead.
        """
        reported = context.get("exception")
        if not isinstance(reported, Exception):
            loop.default_exception_handler(context)
            return

        self._exceptions.append(reported)

    def _raise_async_exceptions(self) -> None:
        """
        Re-raise exceptions collected from asynchronous callbacks, if any.

        The collection is emptied first. A single exception is raised as is; several
        are wrapped in a ``BaseExceptionGroup``.
        """
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """
        Body of the runner task: await each received awaitable in turn.

        The outcome of every awaitable is transferred to the future it was paired
        with, unless that future was cancelled in the meantime.

        :param receive_stream: stream of ``(awaitable, future)`` pairs to process

        """
        from _pytest.outcomes import OutcomeException

        with receive_stream, self._send_stream:
            async for awaitable, outcome in receive_stream:
                try:
                    retval = await awaitable
                except CancelledError as exc:
                    if not outcome.cancelled():
                        outcome.cancel(*exc.args)

                    # Cancellation always ends the runner task
                    raise
                except BaseException as exc:
                    if not outcome.cancelled():
                        outcome.set_exception(exc)

                    # Ordinary errors and pytest outcomes only fail the current
                    # item; anything else also terminates the runner task
                    if isinstance(exc, (Exception, OutcomeException)):
                        continue

                    raise
                else:
                    if not outcome.cancelled():
                        outcome.set_result(retval)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        """
        Call ``func`` and have the runner task await what it returns.

        The runner task and the memory object stream feeding it are created on the
        first call.

        :return: the result produced by the awaitable

        """
        if not self._runner_task:
            self._send_stream, receive_stream = create_memory_object_stream[
                tuple[Awaitable[Any], asyncio.Future]
            ](1)
            runner_coro = self._run_tests_and_fixtures(receive_stream)
            self._runner_task = self.get_loop().create_task(runner_coro)

        awaitable = func(*args, **kwargs)
        outcome: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((awaitable, outcome))
        return await outcome

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Run an async generator fixture: yield its value, then finalize it.

        :param fixture_func: the async generator function implementing the fixture
        :param kwargs: keyword arguments for ``fixture_func``
        :raises RuntimeError: if the generator yields a second time

        """
        asyncgen = fixture_func(**kwargs)

        # Setup phase: advance the generator to its (only) yield
        setup_call = self._call_in_runner_task(asyncgen.asend, None)
        fixturevalue: T_Retval = self.get_loop().run_until_complete(setup_call)
        self._raise_async_exceptions()

        yield fixturevalue

        # Teardown phase: the generator is expected to finish now
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )
        except StopAsyncIteration:
            self._raise_async_exceptions()
            return

        self.get_loop().run_until_complete(asyncgen.aclose())
        raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """
        Run a coroutine fixture in the runner task and return its value.

        :param fixture_func: the coroutine function implementing the fixture
        :param kwargs: keyword arguments for ``fixture_func``

        """
        fixture_call = self._call_in_runner_task(fixture_func, **kwargs)
        retval = self.get_loop().run_until_complete(fixture_call)
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """
        Run a test coroutine function in the runner task.

        An ``Exception`` raised while running the test is added to the collected
        exceptions, which are then re-raised together.

        :param test_func: the coroutine function implementing the test
        :param kwargs: keyword arguments for ``test_func``

        """
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
class AsyncIOBackend(AsyncBackend):
    """
    Backend implementation on top of :mod:`asyncio`.

    All methods are class methods; the class object itself is exported as
    ``backend_class`` at the end of this module.

    Inside the method bodies, the bare name ``sleep`` always refers to the
    module-level ``asyncio.sleep`` import, not to :meth:`sleep` of this class.
    """

    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func(*args)`` to completion inside a new ``Runner``.

        :param func: the coroutine function to run
        :param args: positional arguments passed to ``func``
        :param kwargs: accepted as part of the backend interface; not forwarded to
            ``func``
        :param options: backend options; the keys ``debug``, ``loop_factory`` and
            ``use_uvloop`` are read
        :return: the return value of ``func``

        """

        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            # Register the root task so that task state lookups succeed for it
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", None)
        loop_factory = options.get("loop_factory", None)
        # An explicit loop factory takes precedence over the ``use_uvloop`` flag
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, which serves as this backend's token."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the current time according to the running event loop's clock."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class used to signal cancellation."""
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield to the event loop only if an enclosing cancel scope was cancelled.

        The cancel scopes of the current task are walked from the innermost one
        outwards; the walk stops at the first shielded scope. Nothing happens if
        there is no current task or no state is registered for it.
        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                # NOTE(review): the loop only ends here if this await raises, which
                # is presumably how the pending cancellation gets delivered
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield control to the event loop once, inside a shielded cancel scope."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        """
        Suspend the current task for ``delay`` seconds.

        :param delay: the duration to sleep, passed to ``asyncio.sleep``

        """
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """
        Create a new cancel scope.

        :param deadline: deadline of the scope (defaults to no deadline)
        :param shield: whether the scope is shielded

        """
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Compute the nearest deadline applying to the current task.

        :return: the smallest deadline among the task's cancel scopes up to and
            including the first shielded one, ``-inf`` if one of them has already
            been cancelled, or ``inf`` when there is no current task or no state
            registered for it

        """
        if (task := current_task()) is None:
            return math.inf

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                # An already cancelled scope means the deadline has passed
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                # Scopes outside a shielded scope have no effect on this task
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group."""
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event."""
        return Event()

    @classmethod
    def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
        """
        Create a new lock.

        :param fast_acquire: passed through to ``Lock``

        """
        return Lock(fast_acquire=fast_acquire)

    @classmethod
    def create_semaphore(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> abc.Semaphore:
        """
        Create a new semaphore.

        :param initial_value: passed through to ``Semaphore``
        :param max_value: passed through to ``Semaphore``
        :param fast_acquire: passed through to ``Semaphore``

        """
        return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """
        Create a new capacity limiter.

        :param total_tokens: passed through to ``CapacityLimiter``

        """
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(  # type: ignore[return]
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in a worker thread and wait for its result.

        An idle worker thread is reused when available; otherwise a new one is
        started. Workers that have been idle for ``WorkerThread.MAX_IDLE_TIME``
        seconds or longer are stopped along the way.

        :param func: the callable to run in the worker thread
        :param args: positional arguments passed to ``func``
        :param abandon_on_cancel: if ``False``, the wait for the result happens in a
            shielded cancel scope
        :param limiter: capacity limiter to hold while the call is in progress
            (the default thread limiter is used if omitted)
        :return: the result delivered through the worker's future

        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future = asyncio.Future[T_Retval]()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    # Stop the worker when the root task finishes
                    root_task.add_done_callback(
                        worker.stop, context=contextvars.Context()
                    )
                else:
                    # Take the most recently idled worker from the right end
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # The worker runs the call in a copy of the current context where no
                # async library is marked as current
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future

    @classmethod
    def check_cancelled(cls) -> None:
        """
        Raise ``CancelledError`` if the thread-local cancel scope chain is cancelled.

        The chain starts at ``threadlocals.current_cancel_scope`` and is followed
        outwards until a cancelled scope is found or a shielded scope ends the walk.

        :raises CancelledError: if a scope with ``cancel_called`` set is reached

        """
        scope: CancelScope | None = threadlocals.current_cancel_scope
        while scope is not None:
            if scope.cancel_called:
                raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")

            if scope.shield:
                return

            scope = scope._parent_scope

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in the event loop and block until it completes.

        :param func: the coroutine function to run in the event loop
        :param args: positional arguments passed to ``func``
        :param token: the event loop to submit the coroutine to
        :return: the return value of ``func``
        :raises concurrent.futures.CancelledError: if the call was cancelled

        """

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            __tracebackhide__ = True
            task = cast(asyncio.Task, current_task())
            # Attach the new task to the cancel scope taken from the calling thread
            _task_states[task] = TaskState(None, scope)
            scope._tasks.add(task)
            try:
                return await func(*args)
            except CancelledError as exc:
                # Translate to the exception type of the future handed to the thread
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(task)

        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        wrapper = task_wrapper(threadlocals.current_cancel_scope)
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, wrapper, loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in the event loop thread and block for the result.

        :param func: the callable to run in the event loop thread
        :param args: positional arguments passed to ``func``
        :param token: the event loop to schedule the call on
        :return: the return value of ``func``

        """

        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Non-Exception errors are also re-raised in the event loop
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal."""
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: StrOrBytesPath | Sequence[StrOrBytesPath],
        *,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        **kwargs: Any,
    ) -> Process:
        """
        Start a subprocess.

        :param command: a string, bytes or path-like object to run through the
            shell, or a sequence of arguments to execute directly
        :param stdin: passed through to the asyncio subprocess call
        :param stdout: passed through to the asyncio subprocess call
        :param stderr: passed through to the asyncio subprocess call
        :param kwargs: further keyword arguments for the asyncio subprocess call
        :return: a ``Process`` wrapping the subprocess and its available streams

        """
        await cls.checkpoint()
        if isinstance(command, PathLike):
            command = os.fspath(command)

        if isinstance(command, (str, bytes)):
            # NOTE(review): a plain string is interpreted by the shell, so it must
            # never be assembled from untrusted input
            process = await asyncio.create_subprocess_shell(
                command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )

        # Only wrap the pipes that were actually created
        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the given worker processes to be shut down with the root task.

        A shutdown task is started, and a forcible shutdown is additionally attached
        to the root task as a done callback.

        :param workers: the set of worker processes to shut down

        """
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)  # type:ignore[arg-type]
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """
        Open a TCP connection.

        :param host: host to connect to
        :param port: port to connect to
        :param local_address: local address to bind to, if any
        :return: a socket stream wrapping the new connection

        """
        transport, protocol = cast(
            tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """
        Connect to a UNIX socket.

        The socket is closed if connecting fails for any reason, including
        cancellation while waiting for the connection to complete.

        :param path: path of the socket to connect to
        :return: a socket stream wrapping the connected socket

        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        try:
            raw_socket.setblocking(False)
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    # Wait until the socket becomes writable, then retry
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            # Covers errors from connect() as well as cancellation of the wait
            raw_socket.close()
            raise

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """
        Wrap a socket in a TCP listener.

        :param sock: the socket to wrap

        """
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """
        Wrap a socket in a UNIX socket listener.

        :param sock: the socket to wrap

        """
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a UDP socket.

        :param family: address family of the socket
        :param local_address: local address to bind to, if any
        :param remote_address: remote address to connect to, if any
        :param reuse_port: passed through to ``create_datagram_endpoint()``
        :return: a connected UDP socket if ``remote_address`` was given, otherwise
            an unconnected one

        """
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap a UNIX datagram socket, connecting it first if a remote path is given.

        The socket is closed if connecting fails for any reason, including
        cancellation while waiting for the connection to complete.

        :param raw_socket: the socket to wrap
        :param remote_path: path to connect the socket to, if any
        :return: a connected UNIX datagram socket if ``remote_path`` was given,
            otherwise an unconnected one

        """
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            try:
                while True:
                    try:
                        raw_socket.connect(remote_path)
                    except BlockingIOError:
                        # Wait until the socket becomes writable, then retry
                        f: asyncio.Future = asyncio.Future()
                        loop.add_writer(raw_socket, f.set_result, None)
                        f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                        await f
                    else:
                        return ConnectedUNIXDatagramSocket(raw_socket)
            except BaseException:
                # Covers errors from connect() as well as cancellation of the wait
                raw_socket.close()
                raise
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> Sequence[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
        ]
    ]:
        """Delegate to the running event loop's ``getaddrinfo()``."""
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Delegate to the running event loop's ``getnameinfo()``."""
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_readable(cls, obj: FileDescriptorLike) -> None:
        """
        Wait until the given object becomes readable.

        :param obj: a file descriptor, or an object with a ``fileno()`` method
        :raises BusyResourceError: if a wait for readability is already registered
            for the same file descriptor

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if not isinstance(obj, int):
            obj = obj.fileno()

        if read_events.get(obj):
            raise BusyResourceError("reading from")

        loop = get_running_loop()
        event = asyncio.Event()
        try:
            loop.add_reader(obj, event.set)
        except NotImplementedError:
            # The loop cannot watch file descriptors; use the selector fallback
            from anyio._core._asyncio_selector_thread import get_selector

            selector = get_selector()
            selector.add_reader(obj, event.set)
            remove_reader = selector.remove_reader
        else:
            remove_reader = loop.remove_reader

        read_events[obj] = event
        try:
            await event.wait()
        finally:
            # Clear the registry entry first so that a failure to remove the reader
            # cannot leave the descriptor permanently marked as busy
            del read_events[obj]
            remove_reader(obj)

    @classmethod
    async def wait_writable(cls, obj: FileDescriptorLike) -> None:
        """
        Wait until the given object becomes writable.

        :param obj: a file descriptor, or an object with a ``fileno()`` method
        :raises BusyResourceError: if a wait for writability is already registered
            for the same file descriptor

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if not isinstance(obj, int):
            obj = obj.fileno()

        if write_events.get(obj):
            raise BusyResourceError("writing to")

        loop = get_running_loop()
        event = asyncio.Event()
        try:
            loop.add_writer(obj, event.set)
        except NotImplementedError:
            # The loop cannot watch file descriptors; use the selector fallback
            from anyio._core._asyncio_selector_thread import get_selector

            selector = get_selector()
            selector.add_writer(obj, event.set)
            remove_writer = selector.remove_writer
        else:
            remove_writer = loop.remove_writer

        write_events[obj] = event
        try:
            await event.wait()
        finally:
            del write_events[obj]
            remove_writer(obj)

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """
        Return the default capacity limiter for worker threads.

        On first use, a limiter with 40 tokens is created and stored in
        ``_default_thread_limiter``.
        """
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> AbstractContextManager[AsyncIterator[Signals]]:
        """
        Create a signal receiver for the given signals.

        :param signals: the signals to receive

        """
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return a task info object for the current task."""
        return AsyncIOTaskInfo(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> Sequence[TaskInfo]:
        """Return task info objects for all tasks that are not done yet."""
        return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every task other than the current one is waiting on something.

        A task counts as blocked when it has a waiter future that is not done yet.
        As long as some task is not blocked, the check is repeated after sleeping
        for 0.1 seconds.
        """
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """
        Create a test runner.

        :param options: keyword arguments for ``TestRunner``

        """
        return TestRunner(**options)


# The backend is used through the class object itself
backend_class = AsyncIOBackend