| | from __future__ import annotations |
| |
|
| | import array |
| | import asyncio |
| | import concurrent.futures |
| | import contextvars |
| | import math |
| | import os |
| | import socket |
| | import sys |
| | import threading |
| | import weakref |
| | from asyncio import ( |
| | AbstractEventLoop, |
| | CancelledError, |
| | all_tasks, |
| | create_task, |
| | current_task, |
| | get_running_loop, |
| | sleep, |
| | ) |
| | from asyncio.base_events import _run_until_complete_cb |
| | from collections import OrderedDict, deque |
| | from collections.abc import ( |
| | AsyncGenerator, |
| | AsyncIterator, |
| | Awaitable, |
| | Callable, |
| | Collection, |
| | Coroutine, |
| | Iterable, |
| | Sequence, |
| | ) |
| | from concurrent.futures import Future |
| | from contextlib import AbstractContextManager, suppress |
| | from contextvars import Context, copy_context |
| | from dataclasses import dataclass, field |
| | from functools import partial, wraps |
| | from inspect import ( |
| | CORO_RUNNING, |
| | CORO_SUSPENDED, |
| | getcoroutinestate, |
| | iscoroutine, |
| | ) |
| | from io import IOBase |
| | from os import PathLike |
| | from queue import Queue |
| | from signal import Signals |
| | from socket import AddressFamily, SocketKind |
| | from threading import Thread |
| | from types import CodeType, TracebackType |
| | from typing import ( |
| | IO, |
| | TYPE_CHECKING, |
| | Any, |
| | Optional, |
| | TypeVar, |
| | cast, |
| | ) |
| | from weakref import WeakKeyDictionary |
| |
|
| | from .. import ( |
| | CapacityLimiterStatistics, |
| | EventStatistics, |
| | LockStatistics, |
| | TaskInfo, |
| | abc, |
| | ) |
| | from .._core._eventloop import ( |
| | claim_worker_thread, |
| | set_current_async_library, |
| | threadlocals, |
| | ) |
| | from .._core._exceptions import ( |
| | BrokenResourceError, |
| | BusyResourceError, |
| | ClosedResourceError, |
| | EndOfStream, |
| | RunFinishedError, |
| | WouldBlock, |
| | iterate_exceptions, |
| | ) |
| | from .._core._sockets import convert_ipv6_sockaddr |
| | from .._core._streams import create_memory_object_stream |
| | from .._core._synchronization import ( |
| | CapacityLimiter as BaseCapacityLimiter, |
| | ) |
| | from .._core._synchronization import Event as BaseEvent |
| | from .._core._synchronization import Lock as BaseLock |
| | from .._core._synchronization import ( |
| | ResourceGuard, |
| | SemaphoreStatistics, |
| | ) |
| | from .._core._synchronization import Semaphore as BaseSemaphore |
| | from .._core._tasks import CancelScope as BaseCancelScope |
| | from ..abc import ( |
| | AsyncBackend, |
| | IPSockAddrType, |
| | SocketListener, |
| | UDPPacketType, |
| | UNIXDatagramPacketType, |
| | ) |
| | from ..abc._eventloop import StrOrBytesPath |
| | from ..lowlevel import RunVar |
| | from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
| |
|
| | if TYPE_CHECKING: |
| | from _typeshed import FileDescriptorLike |
| | else: |
| | FileDescriptorLike = object |
| |
|
| | if sys.version_info >= (3, 10): |
| | from typing import ParamSpec |
| | else: |
| | from typing_extensions import ParamSpec |
| |
|
| | if sys.version_info >= (3, 11): |
| | from asyncio import Runner |
| | from typing import TypeVarTuple, Unpack |
| | else: |
| | import contextvars |
| | import enum |
| | import signal |
| | from asyncio import coroutines, events, exceptions, tasks |
| |
|
| | from exceptiongroup import BaseExceptionGroup |
| | from typing_extensions import TypeVarTuple, Unpack |
| |
|
class _State(enum.Enum):
    """Lifecycle states of the backported :class:`Runner`."""

    CREATED = "created"  # constructed; no event loop exists yet
    INITIALIZED = "initialized"  # _lazy_init() has created the event loop
    CLOSED = "closed"  # close() has run; _lazy_init() now raises RuntimeError
| |
|
class Runner:
    """Backport of :class:`asyncio.Runner` for Python versions before 3.11.

    The event loop is created lazily on first use (entering the ``with`` block,
    :meth:`get_loop` or :meth:`run`) and torn down by :meth:`close`.
    """

    def __init__(
        self,
        *,
        debug: bool | None = None,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ):
        self._state = _State.CREATED
        self._debug = debug
        self._loop_factory = loop_factory
        self._loop: AbstractEventLoop | None = None
        # Context copied in _lazy_init(); run() creates tasks in it by default
        self._context = None
        # Number of SIGINTs received during the current run() call
        self._interrupt_count = 0
        # True once _lazy_init() installed the loop with set_event_loop(); close()
        # then resets the current event loop to None
        self._set_event_loop = False

    def __enter__(self) -> Runner:
        self._lazy_init()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """Shutdown and close event loop.

        Cancels all remaining tasks, shuts down async generators and the default
        executor, then closes the loop. Does nothing unless the runner is
        initialized; afterwards the runner is permanently closed.
        """
        loop = self._loop
        if self._state is not _State.INITIALIZED or loop is None:
            return
        try:
            _cancel_all_tasks(loop)
            loop.run_until_complete(loop.shutdown_asyncgens())
            # Loops lacking shutdown_default_executor() use the local fallback
            if hasattr(loop, "shutdown_default_executor"):
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                loop.run_until_complete(_shutdown_default_executor(loop))
        finally:
            if self._set_event_loop:
                events.set_event_loop(None)
            loop.close()
            self._loop = None
            self._state = _State.CLOSED

    def get_loop(self) -> AbstractEventLoop:
        """Return embedded event loop, creating it first if necessary."""
        self._lazy_init()
        return self._loop

    def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
        """Run a coroutine inside the embedded event loop.

        :param coro: the coroutine to run to completion
        :param context: the :class:`contextvars.Context` to create the task in;
            defaults to the context copied when the runner was initialized
        :return: the coroutine's return value
        :raises ValueError: if ``coro`` is not a coroutine
        :raises RuntimeError: if an event loop is already running in this thread,
            or if the runner has been closed
        :raises KeyboardInterrupt: on a second SIGINT, or when a SIGINT-triggered
            cancellation ends the task and ``task.uncancel()`` returns 0
        """
        if not coroutines.iscoroutine(coro):
            raise ValueError(f"a coroutine was expected, got {coro!r}")

        if events._get_running_loop() is not None:
            # Fail fast with a clear error rather than nesting event loops
            raise RuntimeError(
                "Runner.run() cannot be called from a running event loop"
            )

        self._lazy_init()

        if context is None:
            context = self._context
        task = context.run(self._loop.create_task, coro)

        # Take over SIGINT (to cancel the task) only in the main thread and only
        # when the default handler is still installed
        if (
            threading.current_thread() is threading.main_thread()
            and signal.getsignal(signal.SIGINT) is signal.default_int_handler
        ):
            sigint_handler = partial(self._on_sigint, main_task=task)
            try:
                signal.signal(signal.SIGINT, sigint_handler)
            except ValueError:
                # signal.signal() refused to install the handler; run without
                # SIGINT integration in that case
                sigint_handler = None
        else:
            sigint_handler = None

        self._interrupt_count = 0
        try:
            return self._loop.run_until_complete(task)
        except exceptions.CancelledError:
            # If our SIGINT handler caused the cancellation and no other
            # cancellation request remains, report it as KeyboardInterrupt
            if self._interrupt_count > 0:
                uncancel = getattr(task, "uncancel", None)
                if uncancel is not None and uncancel() == 0:
                    raise KeyboardInterrupt
            raise
        finally:
            # Restore the default handler only if ours is still the active one
            if (
                sigint_handler is not None
                and signal.getsignal(signal.SIGINT) is sigint_handler
            ):
                signal.signal(signal.SIGINT, signal.default_int_handler)

    def _lazy_init(self) -> None:
        """Create and configure the event loop on first use.

        :raises RuntimeError: if the runner has already been closed
        """
        if self._state is _State.CLOSED:
            raise RuntimeError("Runner is closed")
        if self._state is _State.INITIALIZED:
            return
        if self._loop_factory is None:
            self._loop = events.new_event_loop()
            if not self._set_event_loop:
                # Install the loop as the current event loop only once
                events.set_event_loop(self._loop)
                self._set_event_loop = True
        else:
            # Loops from a custom factory are not installed as the current loop
            self._loop = self._loop_factory()
        if self._debug is not None:
            self._loop.set_debug(self._debug)
        self._context = contextvars.copy_context()
        self._state = _State.INITIALIZED

    def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
        """SIGINT handler installed by :meth:`run`.

        The first interrupt cancels ``main_task`` (if it is not done yet); any
        further interrupt, or one arriving after the task finished, raises
        KeyboardInterrupt.
        """
        self._interrupt_count += 1
        if self._interrupt_count == 1 and not main_task.done():
            main_task.cancel()
            # Schedule a no-op thread-safely so the loop wakes up and processes
            # the cancellation
            self._loop.call_soon_threadsafe(lambda: None)
            return
        raise KeyboardInterrupt()
| |
|
def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
    """Cancel every unfinished task on *loop* and wait for all of them to end.

    A task that ends with an exception other than cancellation is reported to
    the loop's exception handler, mirroring ``asyncio.run()`` shutdown.
    """
    pending = tasks.all_tasks(loop)
    if pending:
        for pending_task in pending:
            pending_task.cancel()

        # Collect results/exceptions so nothing is left un-retrieved
        loop.run_until_complete(tasks.gather(*pending, return_exceptions=True))

        for finished in pending:
            if not finished.cancelled() and finished.exception() is not None:
                loop.call_exception_handler(
                    {
                        "message": "unhandled exception during asyncio.run() shutdown",
                        "exception": finished.exception(),
                        "task": finished,
                    }
                )
| |
|
async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
    """Schedule the shutdown of the default executor.

    Fallback for event loops without ``shutdown_default_executor()``. The loop
    is flagged as having had its executor shut down. If a default executor
    exists, ``shutdown(wait=True)`` runs in a helper thread so the event loop is
    not blocked, and the outcome is relayed back through a future.
    """

    def _shutdown_in_thread(done: asyncio.futures.Future) -> None:
        # Both the shutdown and the success report sit in the try block, so any
        # failure is forwarded to the awaiting coroutine instead.
        try:
            loop._default_executor.shutdown(wait=True)
            loop.call_soon_threadsafe(done.set_result, None)
        except Exception as error:
            loop.call_soon_threadsafe(done.set_exception, error)

    loop._executor_shutdown_called = True
    if loop._default_executor is not None:
        done = loop.create_future()
        helper = threading.Thread(target=_shutdown_in_thread, args=(done,))
        helper.start()
        try:
            await done
        finally:
            helper.join()
| |
|
| |
|
# Type variables shared by the annotations in this module
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")

# Root task cached by find_root_task() once it has identified the task the event
# loop was started with
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
| |
|
| |
|
def find_root_task() -> asyncio.Task:
    """Return the task at the root of the current event loop's task tree.

    Lookup order:

    1. the cached root task, if it has not finished yet;
    2. an unfinished task that has ``run_until_complete()``'s done-callback (or
       any callback from ``uvloop.loop``) attached, i.e. the task the loop was
       started with (this result is cached);
    3. the host task of the outermost AnyIO cancel scope of the current task;
    4. the current task itself.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        # Snapshot the callbacks before inspecting them
        registered = [callback for callback, _ctx in candidate._callbacks]
        if any(
            callback is _run_until_complete_cb
            or getattr(callback, "__module__", None) == "uvloop.loop"
            for callback in registered
        ):
            _root_task.set(candidate)
            return candidate

    # Fall back to climbing the AnyIO cancel scope tree of the current task
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if state:
        outermost = state.cancel_scope
        while outermost and outermost._parent_scope is not None:
            outermost = outermost._parent_scope

        if outermost is not None:
            return cast(asyncio.Task, outermost._host_task)

    return current
| |
|
| |
|
def get_callable_name(func: Callable) -> str:
    """Return ``"<module>.<qualname>"`` for *func*.

    Whichever of ``__module__`` / ``__qualname__`` is missing or empty is left
    out, so the result may be just one of them or an empty string.
    """
    name_parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(filter(None, name_parts))
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Values keyed weakly by event loop, so an entry disappears once its loop is
# garbage collected (presumably backing RunVar storage; not used in this chunk)
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
| |
|
| |
|
def _task_started(task: asyncio.Task) -> bool:
    """Return ``True`` if the task has been started and has not finished.

    :raises Exception: if the task's coroutine state cannot be determined
    """
    coro = task.get_coro()
    # Callers only pass tasks still tracked by a cancel scope, so the
    # coroutine is expected to be present
    assert coro is not None
    try:
        state = getcoroutinestate(coro)
    except AttributeError:
        # The task's awaitable is not a real coroutine object
        raise Exception(f"Cannot determine if task {task} has started or not") from None

    return state in (CORO_RUNNING, CORO_SUSPENDED)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def is_anyio_cancellation(exc: CancelledError) -> bool:
    """Return ``True`` if *exc* originates from an AnyIO cancel scope.

    Third-party code sometimes catches a cancel scope's CancelledError and
    raises a fresh one, so the chain of CancelledErrors linked through
    ``__context__`` is searched too, looking for the cancel scope's message
    prefix in the first argument.
    """
    candidate = exc
    while True:
        first_arg = candidate.args[0] if candidate.args else None
        if isinstance(first_arg, str) and first_arg.startswith(
            "Cancelled via cancel scope "
        ):
            return True

        previous = candidate.__context__
        if not isinstance(previous, CancelledError):
            return False

        candidate = previous
| |
|
| |
|
class CancelScope(BaseCancelScope):
    """asyncio implementation of an AnyIO cancel scope.

    Cancellation is delivered with ``Task.cancel()`` using a message that starts
    with ``"Cancelled via cancel scope "``; :func:`is_anyio_cancellation`
    recognizes such CancelledErrors so that :meth:`__exit__` can swallow them.
    Delivery is retried via ``loop.call_soon()`` until no tasks remain in the
    scope.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Create the instance directly. Presumably this bypasses a dispatching
        # BaseCancelScope.__new__ (defined elsewhere) - TODO confirm
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        # Enclosing scope in the host task; set in __enter__()
        self._parent_scope: CancelScope | None = None
        # Active scopes nested directly inside this one
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        # CancelledError message built by cancel() and passed to Task.cancel()
        self._cancel_reason: str | None = None
        self._cancelled_caught = False
        self._active = False
        # Timer that calls _timeout() at the deadline
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Scheduled retry of _deliver_cancellation(); None when none is pending
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks whose innermost cancel scope is this one
        self._tasks: set[asyncio.Task] = set()
        # Task that entered the scope; reset to None when the scope is exited
        self._host_task: asyncio.Task | None = None
        # Count of Task.cancel() calls on the host task that __exit__() must
        # balance with Task.uncancel(); None disables this before Python 3.11
        if sys.version_info >= (3, 11):
            self._pending_uncancellations: int | None = 0
        else:
            self._pending_uncancellations = None

    def __enter__(self) -> CancelScope:
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First cancel scope seen in this task: there is no parent scope
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest inside the task's current innermost scope
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                # The host task now belongs to this scope instead of the parent;
                # discard() tolerates the parent not containing the host task
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.discard(host_task)

        self._timeout()
        self._active = True

        # A scope cancelled before it was entered starts cancelling immediately
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Leave the scope.

        :return: ``True`` (swallow the exception) only if this scope was
            cancelled, no cancelled parent scope is visible through it, and every
            exception yielded by ``iterate_exceptions(exc_val)`` is a
            CancelledError raised by an AnyIO cancel scope (with at least one
            such exception present)
        :raises RuntimeError: if the scope is not active, is exited in another
            task than it was entered in, or is not the task's innermost scope
        """
        del exc_tb

        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        try:
            self._active = False
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            # Hand the host task back to the parent scope
            self._tasks.remove(self._host_task)
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.remove(self)
                self._parent_scope._tasks.add(self._host_task)

            host_task_state.cancel_scope = self._parent_scope

            # The host task rejoined the parent; resume delivery in the closest
            # cancelled ancestor reachable without crossing a shield
            self._restart_cancellation_in_parent()

            # The cancellation is ours to absorb only if we were cancelled and no
            # cancelled ancestor is visible through us
            if self._cancel_called and not self._parent_cancellation_is_visible_to_us:
                # Undo every Task.cancel() this scope made on the host task
                while self._pending_uncancellations:
                    self._host_task.uncancel()
                    self._pending_uncancellations -= 1

                # Record caught cancellations; anything else must propagate
                cannot_swallow_exc_val = False
                if exc_val is not None:
                    for exc in iterate_exceptions(exc_val):
                        if isinstance(exc, CancelledError) and is_anyio_cancellation(
                            exc
                        ):
                            self._cancelled_caught = True
                        else:
                            cannot_swallow_exc_val = True

                return self._cancelled_caught and not cannot_swallow_exc_val
            else:
                # Cancellation (if any) continues through a parent scope, so pass
                # the pending uncancellations up to it
                if self._pending_uncancellations:
                    assert self._parent_scope is not None
                    assert self._parent_scope._pending_uncancellations is not None
                    self._parent_scope._pending_uncancellations += (
                        self._pending_uncancellations
                    )
                    self._pending_uncancellations = 0

                return False
        finally:
            # Drop references to the host task and the exception
            self._host_task = None
            del exc_val

    @property
    def _effectively_cancelled(self) -> bool:
        """Whether this scope, or an ancestor reachable without crossing a
        shielded scope, has been cancelled (a shielded scope's own cancellation
        still counts)."""
        cancel_scope: CancelScope | None = self
        while cancel_scope is not None:
            if cancel_scope._cancel_called:
                return True

            if cancel_scope.shield:
                return False

            cancel_scope = cancel_scope._parent_scope

        return False

    @property
    def _parent_cancellation_is_visible_to_us(self) -> bool:
        """Whether this scope is unshielded and its parent is effectively
        cancelled."""
        return (
            self._parent_scope is not None
            and not self.shield
            and self._parent_scope._effectively_cancelled
        )

    def _timeout(self) -> None:
        """Cancel the scope if its deadline has passed, otherwise arm a timer to
        call this again at the deadline. No-op for an infinite deadline."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel("deadline exceeded")
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Any task still inside the scope warrants another delivery round
            should_retry = True
            # asyncio-internal flag: a cancellation is already pending for it
            if task._must_cancel:
                continue

            # Never cancel the running task itself; other tasks only once they
            # have started (the host task always qualifies)
            if task is not current and (task is self._host_task or _task_started(task)):
                waiter = task._fut_waiter
                # Skip tasks whose awaited future has already completed
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    task.cancel(origin._cancel_reason)
                    # Count cancel() calls on the origin's host task so the
                    # origin's __exit__() can uncancel() them
                    if (
                        task is origin._host_task
                        and origin._pending_uncancellations is not None
                    ):
                        origin._pending_uncancellations += 1

        # Recurse into nested scopes that are neither shielded nor cancelled
        # themselves (a cancelled child runs its own delivery)
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Only the originating scope schedules (or clears) the retry
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        The walk stops at the first cancelled ancestor (delivery is restarted
        there unless a retry is already scheduled) or at the first shielded one.
        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # A shielded scope blocks cancellation from further up
            if scope._shield:
                break

            scope = scope._parent_scope

    def cancel(self, reason: str | None = None) -> None:
        """Cancel this scope; repeated calls are no-ops.

        The CancelledError message includes this scope's id, the calling task
        (if any) and *reason*. Delivery starts immediately if the scope has a
        host task; otherwise __enter__() starts it.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            self._cancel_reason = f"Cancelled via cancel scope {id(self):x}"
            if task := current_task():
                self._cancel_reason += f" by {task}"

            if reason:
                self._cancel_reason += f"; reason: {reason}"

            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the timer for the new deadline while it can still matter
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Un-shielding exposes this scope to a cancelled ancestor again
            if not value:
                self._restart_cancellation_in_parent()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class TaskState:
    """
    Per-task bookkeeping kept outside the asyncio Task object, because the Task
    implementation gives no guarantee that extra attributes can be attached.

    ``parent_id`` is the ``id()`` of the logical parent task (or ``None``) and
    ``cancel_scope`` is the task's innermost AnyIO cancel scope (or ``None``).
    """

    # "__weakref__" lets instances be weakly referenced despite __slots__
    __slots__ = ("parent_id", "cancel_scope", "__weakref__")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        self.cancel_scope = cancel_scope
        self.parent_id = parent_id
| |
|
| |
|
# Auxiliary AnyIO state for each known task. Keys are held weakly, so an entry
# does not by itself keep its task alive.
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class _AsyncioTaskStatus(abc.TaskStatus):
    """The ``task_status`` object handed to children started via ``TaskGroup.start()``."""

    def __init__(self, future: asyncio.Future, parent_id: int):
        # Future awaited by the starter; resolved by started()
        self._future = future
        # id() of the task that becomes the child's parent once started
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """Report that the child has started, passing *value* to the starter.

        :raises RuntimeError: if the future was already resolved (other than by
            cancellation), i.e. ``started`` was called twice
        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            # A cancelled future means the starter gave up waiting; tolerate it
            starter_gave_up = self._future.cancelled()
            if not starter_gave_up:
                raise RuntimeError(
                    "called 'started' twice on the same task status"
                ) from None

        # Re-parent the calling (child) task
        calling_task = cast(asyncio.Task, current_task())
        _task_states[calling_task].parent_id = self._parent_id
| |
|
| |
|
# Code object of asyncio.eager_task_factory (Python 3.12+, otherwise None).
# TaskGroup._spawn() compares the loop's task factory against it by identity.
if sys.version_info >= (3, 12):
    _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
else:
    _eager_task_factory_code = None
| |
|
| |
|
class TaskGroup(abc.TaskGroup):
    """asyncio implementation of an AnyIO task group.

    Child tasks are registered in the group's cancel scope. A child that exits
    with an exception (including cancellation) cancels that scope;
    non-cancellation exceptions are collected and raised from
    :meth:`__aexit__` as a BaseExceptionGroup.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        # Non-cancellation exceptions from the host block and child tasks
        self._exceptions: list[BaseException] = []
        # Child tasks that have not finished yet
        self._tasks: set[asyncio.Task] = set()
        # Future that __aexit__() awaits; the last child's done-callback sets it
        self._on_completed_fut: asyncio.Future[None] | None = None

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Wait for all child tasks, then exit the group's cancel scope.

        An exception from the ``async with`` body cancels the group. If any
        exceptions were collected, they are raised as a BaseExceptionGroup;
        otherwise a pending exception (or cancellation) is re-raised unless the
        cancel scope swallows it.
        """
        try:
            if exc_val is not None:
                self.cancel_scope.cancel()
                if not isinstance(exc_val, CancelledError):
                    self._exceptions.append(exc_val)

            loop = get_running_loop()
            try:
                if self._tasks:
                    with CancelScope() as wait_scope:
                        while self._tasks:
                            self._on_completed_fut = loop.create_future()

                            try:
                                await self._on_completed_fut
                            except CancelledError as exc:
                                # Shield the wait against further cancellation
                                # and cancel the children instead
                                wait_scope.shield = True
                                self.cancel_scope.cancel()

                                # Adopt this cancellation if nothing else is
                                # pending, but never replace a native (non-AnyIO)
                                # cancellation with a cancel scope's
                                if exc_val is None or (
                                    isinstance(exc_val, CancelledError)
                                    and not is_anyio_cancellation(exc)
                                ):
                                    exc_val = exc

                            self._on_completed_fut = None
                else:
                    # No children to wait for: still run one checkpoint
                    await AsyncIOBackend.cancel_shielded_checkpoint()

                self._active = False
                if self._exceptions:
                    # Chaining is suppressed: the exception that got us here (if
                    # any) is already included in self._exceptions
                    raise BaseExceptionGroup(
                        "unhandled errors in a TaskGroup", self._exceptions
                    ) from None
                elif exc_val:
                    raise exc_val
            except BaseException as exc:
                if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
                    return True

                raise

            return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Drop references to the exceptions and traceback
            del exc_val, exc_tb, self._exceptions

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """Create a child task running ``func(*args)`` inside this group.

        When *task_status_future* is given (by :meth:`start`), ``func`` also gets
        a ``task_status`` keyword argument. The future then receives the value
        passed to ``started()``, or the child's failure.

        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` does not return a coroutine object
        """

        def task_done(_task: asyncio.Task) -> None:
            # Done-callback: unregister the child and propagate its outcome
            if sys.version_info >= (3, 14) and self.cancel_scope._host_task is not None:
                asyncio.future_discard_from_awaited_by(
                    _task, self.cancel_scope._host_task
                )

            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(task)
            del _task_states[_task]

            # Wake up __aexit__() once the last child has finished
            if self._on_completed_fut is not None and not self._tasks:
                try:
                    self._on_completed_fut.set_result(None)
                except asyncio.InvalidStateError:
                    pass

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Use the innermost CancelledError of the __context__ chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # A cancelled start future means the starter was cancelled;
                # nothing more to report
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._effectively_cancelled:
                        self.cancel_scope.cancel()
                else:
                    # Failed before calling started(): fail the start() call
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # Until started() is called the child's parent is the task calling
            # start(); started() then re-parents it to the group's host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        loop = asyncio.get_running_loop()
        # With asyncio.eager_task_factory installed, create the task through the
        # task constructor captured in the factory's closure instead of through
        # the factory (presumably to avoid eager execution before the task state
        # below is registered - TODO confirm)
        if (
            (factory := loop.get_task_factory())
            and getattr(factory, "__code__", None) is _eager_task_factory_code
            and (closure := getattr(factory, "__closure__", None))
        ):
            custom_task_constructor = closure[0].cell_contents
            task = custom_task_constructor(coro, loop=loop, name=name)
        else:
            task = create_task(coro, name=name)

        # The child inherits the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        if sys.version_info >= (3, 14) and self.cancel_scope._host_task is not None:
            asyncio.future_add_to_awaited_by(task, self.cancel_scope._host_task)

        task.add_done_callback(task_done)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """Start a child task and wait until it calls ``task_status.started()``.

        :return: the value passed to ``task_status.started()``
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If waiting is cancelled, cancel the child and wait (shielded) for it
        # to finish before propagating the cancellation
        try:
            return await future
        except CancelledError:
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
| |
|
| |
|
| | |
| | |
| | |
| |
|
# A (result, exception) pair; presumably exactly one side is meaningful - not
# used within this chunk, confirm against its users
_Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]]
| |
|
| |
|
class WorkerThread(Thread):
    """Thread that runs synchronous jobs submitted from an event loop.

    Jobs arrive through :attr:`queue` as ``(context, func, args, future,
    cancel_scope)`` tuples; ``None`` tells the thread to exit. Results are
    reported back to the event loop with ``call_soon_threadsafe()``.
    """

    # Presumably the idle timeout in seconds; not used within this chunk
    MAX_IDLE_TIME = 10

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        # Registry this thread removes itself from in stop()
        self.workers = workers
        # Idle pool this thread re-joins after each job (unless stopping)
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Bounded to 2 items (presumably one job plus the None stop sentinel)
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """Complete *future* with a job's outcome and mark this worker idle.

        Scheduled on the event loop by :meth:`run` via call_soon_threadsafe().
        A StopIteration is wrapped in a RuntimeError before being set on the
        future. Nothing is set if the future has been cancelled.
        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                if isinstance(exc, StopIteration):
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """Process queued jobs until the ``None`` stop sentinel arrives."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Stop sentinel queued by stop()
                    return

                context, func, args, future, cancel_scope = item
                # Jobs whose future was already cancelled are skipped
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    # Expose the job's cancel scope to code running in this thread
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    # Report back only if the loop can still accept callbacks
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                    del result, exception

                self.queue.task_done()
                # Release references to the finished job before blocking on the
                # next queue.get()
                del item, context, func, args, future, cancel_scope

    def stop(self, f: asyncio.Task | None = None) -> None:
        """Ask the thread to exit and remove it from the worker registries.

        :param f: ignored (presumably allows use as a done-callback)
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # Not currently idle
            pass
| |
|
| |
|
# Run-scoped registries of worker threads: idle ones available for reuse and all
# live ones (presumably passed to WorkerThread as idle_workers / workers)
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Expose an :class:`asyncio.StreamReader` as an AnyIO byte receive stream."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """Read up to *max_bytes*; raise EndOfStream when nothing is returned."""
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        """Mark the reader closed by storing ClosedResourceError on it, then
        checkpoint."""
        self._stream.set_exception(ClosedResourceError())
        await AsyncIOBackend.checkpoint()
| | await AsyncIOBackend.checkpoint() |
| |
|
| |
|
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Expose an :class:`asyncio.StreamWriter` as an AnyIO byte send stream."""

    _stream: asyncio.StreamWriter
    # Set by aclose(); distinguishes "closed by us" from a broken connection
    _closed: bool = field(init=False, default=False)

    async def send(self, item: bytes) -> None:
        """Write *item* and drain the writer.

        :raises ClosedResourceError: if this wrapper was closed
        :raises BrokenResourceError: if the underlying stream is closing
        """
        await AsyncIOBackend.checkpoint_if_cancelled()
        # Whether the protocol was already paused (back-pressure) before writing
        was_paused = self._stream._protocol._paused
        try:
            self._stream.write(item)
            await self._stream.drain()
        except (ConnectionResetError, BrokenPipeError, RuntimeError) as exc:
            # Translate errors caused by a closed connection (by us or the peer)
            if self._closed:
                raise ClosedResourceError from exc
            if self._stream.is_closing():
                raise BrokenResourceError from exc
            raise

        # If not paused beforehand, drain() presumably did not yield to the loop,
        # so add an explicit (shielded) checkpoint
        if not was_paused:
            await AsyncIOBackend.cancel_shielded_checkpoint()

    async def aclose(self) -> None:
        """Close the writer and checkpoint."""
        self._closed = True
        self._stream.close()
        await AsyncIOBackend.checkpoint()
| |
|
| |
|
@dataclass(eq=False)
class Process(abc.Process):
    """AnyIO process wrapper around :class:`asyncio.subprocess.Process`."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close the standard streams, then wait for the process to exit.

        Closing the streams is shielded from cancellation. If the wait is
        interrupted by any exception, the process is killed and waited for
        under a shield before the exception propagates.
        """
        with CancelScope(shield=True) as scope:
            for pipe in (self._stdin, self._stdout, self._stderr):
                if pipe:
                    await pipe.aclose()

            scope.shield = False
            try:
                await self.wait()
            except BaseException:
                scope.shield = True
                self.kill()
                await self.wait()
                raise

    async def wait(self) -> int:
        """Wait for the process to exit and return its exit code."""
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
| |
|
| |
|
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shuts down worker processes belonging to this event loop.

    Runs synchronously (no ``await``): for each worker that has already exited, it
    closes the three stdio pipe transports, calls ``kill()`` and, when a child watcher
    is available, unregisters the child handler for its pid.

    :param workers: the worker processes; a copy is iterated, the set is not mutated
    :param _task: unused (presumably this is registered as a task done-callback –
        confirm at the registration site)
    """
    # On Python < 3.12 try to obtain the policy's child watcher so that handlers for
    # the processes handled below can be removed; NotImplementedError is ignored
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # NOTE(review): processes still running (returncode is None) are skipped and only
    # already-exited ones are handled, which looks inverted relative to "forcibly shuts
    # down" – confirm this is intentional.
    for process in workers.copy():
        if process.returncode is None:
            continue

        # assumes every worker has all three stdio wrappers (not None) – TODO confirm
        # against the code that spawns the workers
        process._stdin._stream._transport.close()
        process._stdout._stream._transport.close()
        process._stderr._stream._transport.close()
        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
| |
|
| |
|
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    Sleeps until cancelled. On cancellation, every worker that is still running is
    killed and then all workers are closed. The cancellation itself is swallowed.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    process: abc.Process
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        # Iterate over a snapshot so changes to ``workers`` during the awaits below
        # cannot disturb the iteration
        snapshot = workers.copy()
        still_running = [proc for proc in snapshot if proc.returncode is None]
        for process in still_running:
            process.kill()

        for process in snapshot:
            await process.aclose()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class StreamProtocol(asyncio.Protocol):
    """
    Stream protocol that buffers incoming data for :class:`SocketStream`.

    ``read_event`` is set when data arrives, on EOF and when the connection is lost.
    ``write_event`` is replaced by a fresh, unset event when the transport pauses
    writing, and is set again when writing resumes.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None
    is_at_eof: bool = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()
        # Write buffer high-water mark of 0
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up any reader or writer so they can observe the lost connection
        for event in (self.read_event, self.write_event):
            event.set()

    def data_received(self, data: bytes) -> None:
        # Store an immutable copy in case the transport hands over a mutable buffer
        # (e.g. a bytearray)
        chunk = bytes(data)
        self.read_queue.append(chunk)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.is_at_eof = True
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        # A brand new (unset) event, rather than clearing the current one
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
| |
|
| |
|
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Datagram protocol that queues received packets for the UDP socket classes.

    ``read_event`` is set whenever a packet arrives or the connection is lost.
    ``write_event`` is cleared while the transport has paused writing.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bounded buffer: once 100 packets are queued, appending drops the oldest one
        self.read_queue = deque(maxlen=100)
        self.read_event, self.write_event = asyncio.Event(), asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        for event in (self.read_event, self.write_event):
            event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        self.read_queue.append((data, convert_ipv6_sockaddr(addr)))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
| |
|
| |
|
class SocketStream(abc.SocketStream):
    """Byte stream over an asyncio transport driven by :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of at most *max_bytes* bytes.

        :raises ClosedResourceError: if nothing is buffered and this stream was closed
        :raises EndOfStream: if nothing is buffered and the peer has finished sending
        """
        protocol, transport = self._protocol, self._transport
        with self._receive_guard:
            must_wait = not (
                protocol.read_event.is_set()
                or transport.is_closing()
                or protocol.is_at_eof
            )
            if must_wait:
                # Only read from the transport while someone is waiting for data
                transport.resume_reading()
                await protocol.read_event.wait()
                transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            queue = protocol.read_queue
            try:
                chunk = queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                if protocol.exception:
                    raise protocol.exception from None
                raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Push back what does not fit so the next receive() gets it first
                queue.appendleft(chunk[max_bytes:])
                chunk = chunk[:max_bytes]

            # Buffer drained: make the next call wait for fresh data
            if not queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write *item* and wait until the protocol allows writing again.

        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: if the transport is closing when writing fails
        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            if self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise
                raise BrokenResourceError from exc

            # Back-pressure: blocks while the protocol has paused writing
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Half-close the write side; OSError from the transport is ignored."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the stream: send EOF if possible, close, yield once, then abort."""
        self._closed = True
        if self._transport.is_closing():
            return

        try:
            self._transport.write_eof()
        except OSError:
            pass

        self._transport.close()
        await sleep(0)
        self._transport.abort()
| |
|
| |
|
class _RawSocketMixin:
    """
    Shared plumbing for classes that drive a raw non-blocking socket directly.

    The ``_wait_until_*`` helpers register a one-shot reader/writer callback with the
    event loop and return a future that completes when the socket becomes ready.
    ``aclose()`` resolves any such pending future so that a blocked operation wakes up.
    """

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        sock = self.__raw_socket
        readable = self._receive_future = asyncio.Future()

        def on_done(_: object) -> None:
            # Drop the instance attribute (falling back to the class-level None) and
            # unregister the reader callback
            del self._receive_future
            loop.remove_reader(sock)

        loop.add_reader(sock, readable.set_result, None)
        readable.add_done_callback(on_done)
        return readable

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        sock = self.__raw_socket
        writable = self._send_future = asyncio.Future()

        def on_done(_: object) -> None:
            del self._send_future
            loop.remove_writer(sock)

        loop.add_writer(sock, writable.set_result, None)
        writable.add_done_callback(on_done)
        return writable

    async def aclose(self) -> None:
        """Close the socket (once) and wake up any pending read/write wait."""
        if self._closing:
            return

        self._closing = True
        sock = self.__raw_socket
        if sock.fileno() != -1:
            sock.close()

        for pending in (self._receive_future, self._send_future):
            if pending:
                pending.set_result(None)
| |
|
| |
|
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """
    UNIX domain stream socket driven directly through the raw non-blocking socket.

    OS errors are translated to :class:`ClosedResourceError` once the stream is
    closing, and to :class:`BrokenResourceError` otherwise.
    """

    async def send_eof(self) -> None:
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if data:
                    return data
                # An empty read means the peer closed its side
                raise EndOfStream

    async def send(self, item: bytes) -> None:
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            # Zero-copy view so that partial sends can be resumed from an offset
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message of up to *msglen* bytes plus up to *maxfds* file descriptors.

        :raises ValueError: on invalid *msglen* / *maxfds*
        :raises EndOfStream: if neither data nor ancillary data was received
        :raises RuntimeError: if ancillary data other than ``SCM_RIGHTS`` arrives
        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        event_loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if not message and not ancdata:
                    raise EndOfStream

                break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore a trailing partial integer, if any
            usable = len(cmsg_data) - len(cmsg_data) % fds.itemsize
            fds.frombytes(cmsg_data[:usable])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send *message* together with the given file descriptors (``SCM_RIGHTS``).

        Entries of *fds* that are neither ``int`` nor :class:`IOBase` are ignored.

        :raises ValueError: if *message* or *fds* is empty
        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        event_loop = get_running_loop()
        filenos: list[int] = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    return
| |
|
| |
|
class TCPSocketListener(abc.SocketListener):
    """
    Listener that accepts TCP connections via ``loop.sock_accept()``.

    Accepted sockets get ``TCP_NODELAY`` and are wrapped as :class:`SocketStream`.
    """

    # Cancel scope of the accept() call in progress (None otherwise), so that aclose()
    # can interrupt it
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next incoming connection.

        :raises ClosedResourceError: if the listener is closed, or gets closed while
            waiting
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Best-effort removal of a reader callback that may remain registered
                    # for the listening socket after a cancelled sock_accept()
                    # (presumably a CPython workaround, cf. bpo-41317 – confirm);
                    # ValueError / NotImplementedError from remove_reader() are ignored
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # A cancellation triggered by aclose() surfaces as
                    # ClosedResourceError
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling an accept() in progress. Idempotent."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Same best-effort reader removal as in accept()
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Yield once, presumably so the cancelled accept() can unwind before the
            # socket is closed
            await sleep(0)

        self._raw_socket.close()
| |
|
| |
|
class UNIXSocketListener(abc.SocketListener):
    """Listener accepting connections on a non-blocking UNIX domain socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next connection and wrap it in a :class:`UNIXSocketStream`.

        :raises ClosedResourceError: if the listener was closed
        :raises BrokenResourceError: on any other OS-level error
        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    conn, _ = self.__raw_socket.accept()
                    conn.setblocking(False)
                    return UNIXSocketStream(conn)
                except BlockingIOError:
                    # No pending connection: wait until the listening socket is readable
                    listener = self.__raw_socket
                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(listener, readable.set_result, None)
                    readable.add_done_callback(
                        lambda _: self._loop.remove_reader(listener)
                    )
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        self._closed = True
        self.__raw_socket.close()
| |
|
| |
|
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket on top of an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        self._closed = True
        transport = self._transport
        if not transport.is_closing():
            transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next ``(payload, address)`` pair.

        :raises ClosedResourceError: if nothing is buffered and this socket was closed
        :raises BrokenResourceError: if nothing is buffered otherwise (e.g. the
            connection was lost)
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            if not (protocol.read_queue or self._transport.is_closing()):
                # Nothing buffered yet: wait for a packet or for connection loss
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """Send ``item`` (a ``(payload, address)`` pair) once writing is allowed."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError
            self._transport.sendto(*item)
| |
|
| |
|
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket on top of an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        self._closed = True
        transport = self._transport
        if not transport.is_closing():
            transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next packet (the sender address is dropped).

        :raises ClosedResourceError: if nothing is buffered and this socket was closed
        :raises BrokenResourceError: if nothing is buffered otherwise
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            if not (protocol.read_queue or self._transport.is_closing()):
                # Nothing buffered yet: wait for a packet or for connection loss
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                raise BrokenResourceError from None

        return packet[0]

    async def send(self, item: bytes) -> None:
        """Send *item* to the connected peer once writing is allowed."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError
            self._transport.sendto(item)
| |
|
| |
|
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket driven through the raw non-blocking socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """Return the next ``(payload, address)`` pair (payload of up to 65536 bytes)."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """Send a ``(payload, address)`` pair, waiting while the socket is not writable."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
| |
|
| |
|
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket driven through the raw non-blocking socket."""

    async def receive(self) -> bytes:
        """Return the next datagram payload (up to 65536 bytes)."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

    async def send(self, item: bytes) -> None:
        """Send *item* to the connected peer, waiting while the socket is not writable."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
| |
|
| |
|
# RunVar-held dicts mapping an int (presumably a file descriptor – confirm at the
# use sites, which are outside this section) to a future for pending read/write
# readiness waits
_read_events: RunVar[dict[int, asyncio.Future[bool]]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Future[bool]]] = RunVar("write_events")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class Event(BaseEvent):
    """asyncio implementation of :class:`BaseEvent`, wrapping an :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Plain allocation, bypassing any base-class __new__
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the flag, waking every task blocked in :meth:`wait`."""
        self._event.set()

    def is_set(self) -> bool:
        return self._event.is_set()

    async def wait(self) -> None:
        """
        Wait until the flag is set.

        If it is already set, still go through ``AsyncIOBackend.checkpoint()``.
        """
        if not self.is_set():
            await self._event.wait()
            return

        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        # Relies on asyncio.Event's private ``_waiters`` collection
        tasks_waiting = len(self._event._waiters)
        return EventStatistics(tasks_waiting)
| |
|
| |
|
class Lock(BaseLock):
    """
    asyncio implementation of :class:`BaseLock`.

    Ownership is handed over directly on release: the first queued task whose future
    is not cancelled becomes the owner and its future is resolved, so waiters are
    served in FIFO order.
    """

    def __new__(cls, *, fast_acquire: bool = False) -> Lock:
        return object.__new__(cls)

    def __init__(self, *, fast_acquire: bool = False) -> None:
        self._fast_acquire = fast_acquire
        self._owner_task: asyncio.Task | None = None
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        """
        Acquire the lock, waiting in FIFO order if it is taken.

        :raises RuntimeError: if the current task already holds the lock
        """
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = task

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        if self._owner_task == task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        fut: asyncio.Future[None] = asyncio.Future()
        item = task, fut
        self._waiters.append(item)
        try:
            await fut
        except CancelledError:
            self._waiters.remove(item)
            if self._owner_task is task or self._owner_task is None:
                # Either the lock had already been handed to us (pass it on), or a
                # release() skipped our cancelled-but-still-queued future and left the
                # lock free while other tasks queued up behind us. In the latter case
                # nobody else would ever wake them, so hand the lock on now.
                self._wake_next_waiter()

            raise

        self._waiters.remove(item)

    def acquire_nowait(self) -> None:
        """
        Acquire the lock without waiting.

        :raises WouldBlock: if the lock is held or tasks are queued for it
        :raises RuntimeError: if the current task already holds the lock
        """
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            self._owner_task = task
            return

        if self._owner_task is task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        raise WouldBlock

    def locked(self) -> bool:
        return self._owner_task is not None

    def release(self) -> None:
        """
        Release the lock, handing it to the next waiter if there is one.

        :raises RuntimeError: if the current task does not hold the lock
        """
        if self._owner_task != current_task():
            raise RuntimeError("The current task is not holding this lock")

        self._wake_next_waiter()

    def _wake_next_waiter(self) -> None:
        """Make the first non-cancelled waiter the owner, or free the lock if none."""
        for task, fut in self._waiters:
            if not fut.cancelled():
                self._owner_task = task
                fut.set_result(None)
                return

        self._owner_task = None

    def statistics(self) -> LockStatistics:
        task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
        return LockStatistics(self.locked(), task_info, len(self._waiters))
| |
|
| |
|
class Semaphore(BaseSemaphore):
    """
    asyncio implementation of :class:`BaseSemaphore`.

    A released token goes directly to the first queued waiter whose future is not
    cancelled; only when there is no such waiter is it added back to ``_value``.
    """

    def __new__(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> Semaphore:
        return object.__new__(cls)

    def __init__(
        self,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ):
        super().__init__(initial_value, max_value=max_value)
        self._value = initial_value
        self._max_value = max_value
        self._fast_acquire = fast_acquire
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        """Acquire a token, waiting in FIFO order if none is available."""
        if self._value > 0 and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        fut: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(fut)
        try:
            await fut
        except CancelledError:
            try:
                self._waiters.remove(fut)
            except ValueError:
                # release() had already handed us a token (and dequeued our future);
                # give it back so it reaches the next waiter or the pool
                self.release()
            else:
                # Our future was cancelled but still queued, so any release() in the
                # meantime skipped it and banked the token in _value. Tasks calling
                # acquire() since then queued up behind us instead of taking it; hand
                # the banked tokens to them now so they are not stranded.
                while self._value > 0 and self._wake_next_waiter():
                    self._value -= 1

            raise

    def acquire_nowait(self) -> None:
        """
        Take a token without waiting.

        :raises WouldBlock: if no token is available
        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """
        Return a token, handing it straight to the next waiter if there is one.

        :raises ValueError: if this would exceed ``max_value``
        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        if not self._wake_next_waiter():
            self._value += 1

    def _wake_next_waiter(self) -> bool:
        """
        Resolve and dequeue the first non-cancelled waiter's future.

        :return: ``True`` if a waiter was woken, ``False`` if there was none
        """
        for fut in self._waiters:
            if not fut.cancelled():
                fut.set_result(None)
                self._waiters.remove(fut)
                return True

        return False

    @property
    def value(self) -> int:
        return self._value

    @property
    def max_value(self) -> int | None:
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        return SemaphoreStatistics(len(self._waiters))
| |
|
| |
|
class CapacityLimiter(BaseCapacityLimiter):
    """
    asyncio implementation of a capacity limiter.

    Borrowers that cannot get a token immediately wait in FIFO order in
    ``_wait_queue``. When capacity frees up, the next waiter is added to
    ``_borrowers`` *before* its event is set. This reserves the token for it, so
    another task (e.g. via :meth:`acquire_nowait`) cannot take the same token before
    the woken waiter gets to run.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")

        if value < 0:
            raise ValueError("total_tokens must be >= 0")

        self._total_tokens = value

        # Wake as many waiters as there is free capacity for now. This is based on the
        # actual number of borrowers, not on the size of the change: raising the limit
        # while more tokens are borrowed than the old limit allowed must not
        # over-admit, and inf -> inf must not compute a NaN delta.
        while self._notify_next_waiter():
            pass

    @property
    def borrowed_tokens(self) -> int:
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        return self._total_tokens - len(self._borrowers)

    def _notify_next_waiter(self) -> bool:
        """
        Hand a token to the next task in line if this limiter has free capacity now.

        The woken borrower is added to ``_borrowers`` immediately, reserving the token.

        :return: ``True`` if a waiter was woken, ``False`` otherwise
        """
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            borrower, event = self._wait_queue.popitem(last=False)
            self._borrowers.add(borrower)
            event.set()
            return True

        return False

    def acquire_nowait(self) -> None:
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Take a token for *borrower* without waiting.

        :raises RuntimeError: if *borrower* already holds a token
        :raises WouldBlock: if tasks are queued or no token is free
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Take a token for *borrower*, waiting in FIFO order if none is free."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A token had already been reserved for us; pass it on
                    self._borrowers.discard(borrower)
                    self._notify_next_waiter()

                raise

            # Nothing more to do: whoever woke us already recorded us as a borrower
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Give back the token taken for *borrower* (which need not be the
                # current task)
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Return *borrower*'s token and wake the next waiter if capacity allows.

        :raises RuntimeError: if *borrower* holds no token
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        self._notify_next_waiter()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
| |
|
| |
|
# Default CapacityLimiter kept in a RunVar (presumably one per event loop run); the
# code that creates and reads it lies outside this section
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class _SignalReceiver:
    """
    Async iterator over OS signals delivered through ``loop.add_signal_handler()``.

    Entering the context manager installs a handler for each distinct requested
    signal. Exiting removes them. Received signals are queued and yielded in order.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[Signals] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        self._signal_queue.append(signum)
        waiter = self._future
        if not waiter.done():
            waiter.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Duplicates in the requested signals are registered only once
        for sig in set(self._signals):
            self._loop.add_signal_handler(sig, self._deliver, sig)
            self._handled_signals.add(sig)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        for sig in self._handled_signals:
            self._loop.remove_signal_handler(sig)

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # Install a fresh future and wait for _deliver() to resolve it
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class AsyncIOTaskInfo(TaskInfo):
    """TaskInfo for an asyncio task; keeps only a weak reference to the task."""

    def __init__(self, task: asyncio.Task):
        task_state = _task_states.get(task)
        parent_id = None if task_state is None else task_state.parent_id

        coro = task.get_coro()
        assert coro is not None, "created TaskInfo from a completed Task"
        super().__init__(id(task), parent_id, task.get_name(), coro)
        self._task = weakref.ref(task)

    def has_pending_cancellation(self) -> bool:
        """
        Report whether the task is about to be cancelled.

        Checks asyncio's private cancellation state first, then whether the task's
        current cancel scope is effectively cancelled. A task that no longer exists
        reports ``False``.
        """
        task = self._task()
        if not task:
            return False

        if task._must_cancel:
            return True

        waiter = task._fut_waiter
        if isinstance(waiter, asyncio.Future) and waiter.cancelled():
            return True

        task_state = _task_states.get(task)
        cancel_scope = task_state.cancel_scope if task_state else None
        if cancel_scope:
            return cancel_scope._effectively_cancelled

        return False
| |
|
| |
|
| | class TestRunner(abc.TestRunner): |
| | _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]] |
| |
|
| | def __init__( |
| | self, |
| | *, |
| | debug: bool | None = None, |
| | use_uvloop: bool = False, |
| | loop_factory: Callable[[], AbstractEventLoop] | None = None, |
| | ) -> None: |
| | if use_uvloop and loop_factory is None: |
| | if sys.platform != "win32": |
| | import uvloop |
| |
|
| | loop_factory = uvloop.new_event_loop |
| | else: |
| | import winloop |
| |
|
| | loop_factory = winloop.new_event_loop |
| |
|
| | self._runner = Runner(debug=debug, loop_factory=loop_factory) |
| | self._exceptions: list[BaseException] = [] |
| | self._runner_task: asyncio.Task | None = None |
| |
|
| | def __enter__(self) -> TestRunner: |
| | self._runner.__enter__() |
| | self.get_loop().set_exception_handler(self._exception_handler) |
| | return self |
| |
|
| | def __exit__( |
| | self, |
| | exc_type: type[BaseException] | None, |
| | exc_val: BaseException | None, |
| | exc_tb: TracebackType | None, |
| | ) -> None: |
| | self._runner.__exit__(exc_type, exc_val, exc_tb) |
| |
|
| | def get_loop(self) -> AbstractEventLoop: |
| | return self._runner.get_loop() |
| |
|
| | def _exception_handler( |
| | self, loop: asyncio.AbstractEventLoop, context: dict[str, Any] |
| | ) -> None: |
| | if isinstance(context.get("exception"), Exception): |
| | self._exceptions.append(context["exception"]) |
| | else: |
| | loop.default_exception_handler(context) |
| |
|
| | def _raise_async_exceptions(self) -> None: |
| | |
| | if self._exceptions: |
| | exceptions, self._exceptions = self._exceptions, [] |
| | if len(exceptions) == 1: |
| | raise exceptions[0] |
| | elif exceptions: |
| | raise BaseExceptionGroup( |
| | "Multiple exceptions occurred in asynchronous callbacks", exceptions |
| | ) |
| |
|
    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """
        Body of the runner task: await each submitted awaitable in turn.

        The outcome of every ``(coro, future)`` pair received from *receive_stream* is
        copied into *future*, unless that future is already cancelled:

        * a return value is set with ``set_result()``
        * ``CancelledError`` cancels the future (with the same args) and then
          propagates, ending this task
        * any other exception is set on the future. It is re-raised here only if it is
          neither an ``Exception`` nor a pytest ``OutcomeException``.
        """
        # Imported inside the function rather than at module level
        from _pytest.outcomes import OutcomeException

        # Both stream ends are closed when this task finishes
        with receive_stream, self._send_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except CancelledError as exc:
                    if not future.cancelled():
                        future.cancel(*exc.args)

                    raise
                except BaseException as exc:
                    if not future.cancelled():
                        future.set_exception(exc)

                    if not isinstance(exc, (Exception, OutcomeException)):
                        raise
                else:
                    if not future.cancelled():
                        future.set_result(retval)
| |
|
| | async def _call_in_runner_task( |
| | self, |
| | func: Callable[P, Awaitable[T_Retval]], |
| | *args: P.args, |
| | **kwargs: P.kwargs, |
| | ) -> T_Retval: |
| | if not self._runner_task: |
| | self._send_stream, receive_stream = create_memory_object_stream[ |
| | tuple[Awaitable[Any], asyncio.Future] |
| | ](1) |
| | self._runner_task = self.get_loop().create_task( |
| | self._run_tests_and_fixtures(receive_stream) |
| | ) |
| |
|
| | coro = func(*args, **kwargs) |
| | future: asyncio.Future[T_Retval] = self.get_loop().create_future() |
| | self._send_stream.send_nowait((coro, future)) |
| | return await future |
| |
|
| | def run_asyncgen_fixture( |
| | self, |
| | fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], |
| | kwargs: dict[str, Any], |
| | ) -> Iterable[T_Retval]: |
| | asyncgen = fixture_func(**kwargs) |
| | fixturevalue: T_Retval = self.get_loop().run_until_complete( |
| | self._call_in_runner_task(asyncgen.asend, None) |
| | ) |
| | self._raise_async_exceptions() |
| |
|
| | yield fixturevalue |
| |
|
| | try: |
| | self.get_loop().run_until_complete( |
| | self._call_in_runner_task(asyncgen.asend, None) |
| | ) |
| | except StopAsyncIteration: |
| | self._raise_async_exceptions() |
| | else: |
| | self.get_loop().run_until_complete(asyncgen.aclose()) |
| | raise RuntimeError("Async generator fixture did not stop") |
| |
|
| | def run_fixture( |
| | self, |
| | fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], |
| | kwargs: dict[str, Any], |
| | ) -> T_Retval: |
| | retval = self.get_loop().run_until_complete( |
| | self._call_in_runner_task(fixture_func, **kwargs) |
| | ) |
| | self._raise_async_exceptions() |
| | return retval |
| |
|
| | def run_test( |
| | self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] |
| | ) -> None: |
| | try: |
| | self.get_loop().run_until_complete( |
| | self._call_in_runner_task(test_func, **kwargs) |
| | ) |
| | except Exception as exc: |
| | self._exceptions.append(exc) |
| |
|
| | self._raise_async_exceptions() |
| |
|
| |
|
class AsyncIOBackend(AsyncBackend):
    """AnyIO backend implemented on top of the standard library :mod:`asyncio`."""

    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func(*args)`` to completion on a new event loop and return its result.

        Recognized ``options``:

        * ``debug``: passed through to :class:`Runner`
        * ``loop_factory``: callable that creates the event loop
        * ``use_uvloop``: use ``uvloop`` (``winloop`` on Windows) when no
          ``loop_factory`` was given

        ``kwargs`` is part of the backend interface but is not used here.
        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            # Name the main task after the function and register a TaskState for
            # it in _task_states for the duration of the call
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", None)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            if sys.platform != "win32":
                import uvloop

                loop_factory = uvloop.new_event_loop
            else:
                import winloop

                loop_factory = winloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, used as this backend's token."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's clock time."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class used to signal cancellation."""
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield to the event loop only if the current task's cancel scopes are cancelled.

        Walks outward from the task's innermost cancel scope and stops at the
        first shielded scope. Returns immediately outside a task or for tasks
        without registered state.
        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                # Give the pending cancellation a chance to be delivered to this
                # task (expected to raise CancelledError here)
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield to the event loop once, shielded from cancellation."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Sleep for ``delay`` seconds."""
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Create a new :class:`CancelScope`."""
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the effective deadline of the current task.

        This is the earliest deadline among the task's cancel scopes, walking
        outward up to and including the first shielded scope. It is ``-inf`` if
        any scope visited has been cancelled, and ``inf`` outside a task or
        without registered task state.
        """
        if (task := current_task()) is None:
            return math.inf

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group."""
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event."""
        return Event()

    @classmethod
    def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
        """Create a new lock."""
        return Lock(fast_acquire=fast_acquire)

    @classmethod
    def create_semaphore(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> abc.Semaphore:
        """Create a new semaphore."""
        return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Create a new capacity limiter with ``total_tokens`` tokens."""
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in a worker thread and return its result.

        Worker threads are pooled and reused; threads idle for at least
        ``WorkerThread.MAX_IDLE_TIME`` are stopped when another one is reused.

        :param abandon_on_cancel: if ``False``, the wait for the result is
            shielded from cancellation of enclosing scopes
        :param limiter: capacity limiter to hold while the thread runs (defaults
            to :meth:`current_default_thread_limiter`)
        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future = asyncio.Future[T_Retval]()
                root_task = find_root_task()
                if not idle_workers:
                    # No idle worker available: start a new one, to be stopped when
                    # the root task finishes
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(
                        worker.stop, context=contextvars.Context()
                    )
                else:
                    # Reuse a worker from the right end; pruning below works from
                    # the left end, assumed to hold the longest-idle workers
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                context = copy_context()
                context.run(set_current_async_library, None)
                # Our own scope is shielded when abandon_on_cancel is False, so from
                # check_cancelled()'s point of view it would stop the walk and never
                # report cancellation; hand the enclosing scope to the worker
                # instead, when there is one
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future

    @classmethod
    def check_cancelled(cls) -> None:
        """
        Raise :exc:`CancelledError` from a worker thread if its scope is cancelled.

        Walks outward from the thread's current cancel scope until a cancelled
        scope (raise) or a shielded one (return) is found.
        """
        scope: CancelScope | None = threadlocals.current_cancel_scope
        while scope is not None:
            if scope.cancel_called:
                raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")

            if scope.shield:
                return

            scope = scope._parent_scope

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` on the event loop from another thread and wait for it.

        The loop is ``token`` if given, otherwise the current thread's loop token.
        If the calling thread has a current cancel scope, the new task is added
        to it.

        :raises RunFinishedError: if the event loop is already closed
        :raises concurrent.futures.CancelledError: if the task gets cancelled
        """

        async def task_wrapper() -> T_Retval:
            __tracebackhide__ = True
            # ``scope`` is assigned in the enclosing function before this
            # coroutine is scheduled
            if scope is not None:
                task = cast(asyncio.Task, current_task())
                _task_states[task] = TaskState(None, scope)
                scope._tasks.add(task)
            try:
                return await func(*args)
            except CancelledError as exc:
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                if scope is not None:
                    scope._tasks.discard(task)

        loop = cast(
            "AbstractEventLoop", token or threadlocals.current_token.native_token
        )
        if loop.is_closed():
            raise RunFinishedError

        context = copy_context()
        context.run(set_current_async_library, "asyncio")
        scope = getattr(threadlocals, "current_cancel_scope", None)
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, task_wrapper(), loop=loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in the event loop thread and wait for its result.

        :raises RunFinishedError: if the event loop is already closed
        """

        @wraps(func)
        def wrapper() -> None:
            try:
                set_current_async_library("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Non-Exception errors (e.g. KeyboardInterrupt) are also
                # propagated into the event loop
                if not isinstance(exc, Exception):
                    raise

        loop = cast(
            "AbstractEventLoop", token or threadlocals.current_token.native_token
        )
        if loop.is_closed():
            raise RunFinishedError

        f: concurrent.futures.Future[T_Retval] = Future()
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    async def open_process(
        cls,
        command: StrOrBytesPath | Sequence[StrOrBytesPath],
        *,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        **kwargs: Any,
    ) -> Process:
        """
        Start a subprocess.

        A ``str``/``bytes``/path-like ``command`` is run through the shell; a
        sequence is executed directly with its first element as the program.
        """
        await cls.checkpoint()
        if isinstance(command, PathLike):
            command = os.fspath(command)

        if isinstance(command, (str, bytes)):
            process = await asyncio.create_subprocess_shell(
                command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the process pool ``workers`` to be shut down.

        Starts a background shutdown task, and registers a forcible shutdown to
        run when the root task finishes.
        """
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """Open a TCP connection to ``host:port``; reading starts paused."""
        transport, protocol = cast(
            tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """
        Connect a non-blocking UNIX stream socket to ``path``.

        The socket is closed if connecting fails for any reason, including
        cancellation while waiting for the socket to become writable.
        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        # The outer try covers the whole loop, not just connect(): an exception
        # raised inside the ``except BlockingIOError`` handler (notably
        # cancellation during ``await f``) would otherwise escape without the
        # socket being closed.
        try:
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    # Wait for the socket to become writable, then retry
                    f: asyncio.Future = loop.create_future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a listening TCP socket."""
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a listening UNIX socket."""
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a UDP socket, connected to ``remote_address`` if one is given.

        If the protocol recorded an exception during setup, the transport is
        closed and that exception is raised.
        """
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap a UNIX datagram socket, first connecting it to ``remote_path`` if given.

        If connecting fails for any reason (including cancellation while waiting
        for writability), ``raw_socket`` is closed.
        """
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            # See connect_unix(): the outer try also covers exceptions raised
            # while awaiting writability inside the BlockingIOError handler
            try:
                while True:
                    try:
                        raw_socket.connect(remote_path)
                    except BlockingIOError:
                        f: asyncio.Future = loop.create_future()
                        loop.add_writer(raw_socket, f.set_result, None)
                        f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                        await f
                    else:
                        return ConnectedUNIXDatagramSocket(raw_socket)
            except BaseException:
                raw_socket.close()
                raise
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> Sequence[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
        ]
    ]:
        """Resolve ``host``/``port`` via the event loop's ``getaddrinfo()``."""
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Look up host and service names via the event loop's ``getnameinfo()``."""
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_readable(cls, obj: FileDescriptorLike) -> None:
        """
        Wait until ``obj`` (a file descriptor or object with ``fileno()``) is readable.

        :raises BusyResourceError: if another task is already waiting to read
            from the same file descriptor
        :raises ClosedResourceError: if :meth:`notify_closing` is called for the
            file descriptor while waiting
        """
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        fd = obj if isinstance(obj, int) else obj.fileno()
        if read_events.get(fd):
            raise BusyResourceError("reading from")

        loop = get_running_loop()
        fut: asyncio.Future[bool] = loop.create_future()

        def cb() -> None:
            # Deregister on first readiness; the finally block below skips the
            # removal if this already happened
            try:
                del read_events[fd]
            except KeyError:
                pass
            else:
                remove_reader(fd)

            try:
                fut.set_result(True)
            except asyncio.InvalidStateError:
                pass

        try:
            loop.add_reader(fd, cb)
        except NotImplementedError:
            # Loops without add_reader() support (e.g. the Windows
            # ProactorEventLoop) fall back to AnyIO's selector thread
            from anyio._core._asyncio_selector_thread import get_selector

            selector = get_selector()
            selector.add_reader(fd, cb)
            remove_reader = selector.remove_reader
        else:
            remove_reader = loop.remove_reader

        read_events[fd] = fut
        try:
            success = await fut
        finally:
            try:
                del read_events[fd]
            except KeyError:
                pass
            else:
                remove_reader(fd)

        # notify_closing() resolves the future with False
        if not success:
            raise ClosedResourceError

    @classmethod
    async def wait_writable(cls, obj: FileDescriptorLike) -> None:
        """
        Wait until ``obj`` (a file descriptor or object with ``fileno()``) is writable.

        :raises BusyResourceError: if another task is already waiting to write to
            the same file descriptor
        :raises ClosedResourceError: if :meth:`notify_closing` is called for the
            file descriptor while waiting
        """
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        fd = obj if isinstance(obj, int) else obj.fileno()
        if write_events.get(fd):
            raise BusyResourceError("writing to")

        loop = get_running_loop()
        fut: asyncio.Future[bool] = loop.create_future()

        def cb() -> None:
            # Deregister on first readiness; the finally block below skips the
            # removal if this already happened
            try:
                del write_events[fd]
            except KeyError:
                pass
            else:
                remove_writer(fd)

            try:
                fut.set_result(True)
            except asyncio.InvalidStateError:
                pass

        try:
            loop.add_writer(fd, cb)
        except NotImplementedError:
            # Loops without add_writer() support (e.g. the Windows
            # ProactorEventLoop) fall back to AnyIO's selector thread
            from anyio._core._asyncio_selector_thread import get_selector

            selector = get_selector()
            selector.add_writer(fd, cb)
            remove_writer = selector.remove_writer
        else:
            remove_writer = loop.remove_writer

        write_events[fd] = fut
        try:
            success = await fut
        finally:
            try:
                del write_events[fd]
            except KeyError:
                pass
            else:
                remove_writer(fd)

        # notify_closing() resolves the future with False
        if not success:
            raise ClosedResourceError

    @classmethod
    def notify_closing(cls, obj: FileDescriptorLike) -> None:
        """
        Wake up tasks waiting on ``obj`` before it gets closed.

        Pending :meth:`wait_readable`/:meth:`wait_writable` calls for the file
        descriptor are resolved with ``False`` (making them raise
        :exc:`ClosedResourceError`), and the descriptor is deregistered from the
        event loop, or from the selector thread when the loop lacks support.
        """
        fd = obj if isinstance(obj, int) else obj.fileno()
        loop = get_running_loop()

        try:
            write_events = _write_events.get()
        except LookupError:
            pass
        else:
            try:
                fut = write_events.pop(fd)
            except KeyError:
                pass
            else:
                try:
                    fut.set_result(False)
                except asyncio.InvalidStateError:
                    pass

                try:
                    loop.remove_writer(fd)
                except NotImplementedError:
                    from anyio._core._asyncio_selector_thread import get_selector

                    get_selector().remove_writer(fd)

        try:
            read_events = _read_events.get()
        except LookupError:
            pass
        else:
            try:
                fut = read_events.pop(fd)
            except KeyError:
                pass
            else:
                try:
                    fut.set_result(False)
                except asyncio.InvalidStateError:
                    pass

                try:
                    loop.remove_reader(fd)
                except NotImplementedError:
                    from anyio._core._asyncio_selector_thread import get_selector

                    get_selector().remove_reader(fd)

    @classmethod
    async def wrap_listener_socket(cls, sock: socket.socket) -> SocketListener:
        """Wrap an existing listening socket as a TCP listener."""
        return TCPSocketListener(sock)

    @classmethod
    async def wrap_stream_socket(cls, sock: socket.socket) -> SocketStream:
        """Wrap an existing connected stream socket; reading starts paused."""
        transport, protocol = await get_running_loop().create_connection(
            StreamProtocol, sock=sock
        )
        # Mirror connect_tcp(): keep reading paused so the protocol does not
        # buffer incoming data before the stream is actually read from
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def wrap_unix_stream_socket(cls, sock: socket.socket) -> UNIXSocketStream:
        """Wrap an existing connected UNIX stream socket."""
        return UNIXSocketStream(sock)

    @classmethod
    async def wrap_udp_socket(cls, sock: socket.socket) -> UDPSocket:
        """Wrap an existing unconnected UDP socket."""
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol, sock=sock
        )
        return UDPSocket(transport, protocol)

    @classmethod
    async def wrap_connected_udp_socket(cls, sock: socket.socket) -> ConnectedUDPSocket:
        """Wrap an existing connected UDP socket."""
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol, sock=sock
        )
        return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def wrap_unix_datagram_socket(cls, sock: socket.socket) -> UNIXDatagramSocket:
        """Wrap an existing unconnected UNIX datagram socket."""
        return UNIXDatagramSocket(sock)

    @classmethod
    async def wrap_connected_unix_datagram_socket(
        cls, sock: socket.socket
    ) -> ConnectedUNIXDatagramSocket:
        """Wrap an existing connected UNIX datagram socket."""
        return ConnectedUNIXDatagramSocket(sock)

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """Return the default worker thread limiter, creating it (40 tokens) if needed."""
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> AbstractContextManager[AsyncIterator[Signals]]:
        """Return a context manager yielding an async iterator of received ``signals``."""
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return information about the current task."""
        return AsyncIOTaskInfo(current_task())

    @classmethod
    def get_running_tasks(cls) -> Sequence[TaskInfo]:
        """Return information about all tasks of the running loop that are not done."""
        return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every other task is blocked on a pending future.

        Polls every 0.1 seconds until a full scan finds no other task that is
        runnable (no awaited future, or an already completed one).
        """
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                # Relies on the private Task._fut_waiter attribute (the future
                # the task is currently blocked on, if any)
                waiter = task._fut_waiter
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """Create a :class:`TestRunner` with ``options`` as keyword arguments."""
        return TestRunner(**options)
| |
|
| |
|
# Module-level alias exposing this module's backend implementation class
# (presumably the name AnyIO's backend loader looks up -- confirm in _core)
backend_class = AsyncIOBackend
| |
|