Spaces:
Runtime error
Runtime error
# Adapted with permission from the EdgeDB project;
# license: PSFL.

# Public name(s) of this module, i.e. what "from ... import *" exports.
__all__ = ["TaskGroup"]

from . import events
from . import exceptions
from . import tasks
| class TaskGroup: | |
| """Asynchronous context manager for managing groups of tasks. | |
| Example use: | |
| async with asyncio.TaskGroup() as group: | |
| task1 = group.create_task(some_coroutine(...)) | |
| task2 = group.create_task(other_coroutine(...)) | |
| print("Both tasks have completed now.") | |
| All tasks are awaited when the context manager exits. | |
| Any exceptions other than `asyncio.CancelledError` raised within | |
| a task will cancel all remaining tasks and wait for them to exit. | |
| The exceptions are then combined and raised as an `ExceptionGroup`. | |
| """ | |
| def __init__(self): | |
| self._entered = False | |
| self._exiting = False | |
| self._aborting = False | |
| self._loop = None | |
| self._parent_task = None | |
| self._parent_cancel_requested = False | |
| self._tasks = set() | |
| self._errors = [] | |
| self._base_error = None | |
| self._on_completed_fut = None | |
| def __repr__(self): | |
| info = [''] | |
| if self._tasks: | |
| info.append(f'tasks={len(self._tasks)}') | |
| if self._errors: | |
| info.append(f'errors={len(self._errors)}') | |
| if self._aborting: | |
| info.append('cancelling') | |
| elif self._entered: | |
| info.append('entered') | |
| info_str = ' '.join(info) | |
| return f'<TaskGroup{info_str}>' | |
| async def __aenter__(self): | |
| if self._entered: | |
| raise RuntimeError( | |
| f"TaskGroup {self!r} has been already entered") | |
| self._entered = True | |
| if self._loop is None: | |
| self._loop = events.get_running_loop() | |
| self._parent_task = tasks.current_task(self._loop) | |
| if self._parent_task is None: | |
| raise RuntimeError( | |
| f'TaskGroup {self!r} cannot determine the parent task') | |
| return self | |
| async def __aexit__(self, et, exc, tb): | |
| self._exiting = True | |
| if (exc is not None and | |
| self._is_base_error(exc) and | |
| self._base_error is None): | |
| self._base_error = exc | |
| propagate_cancellation_error = \ | |
| exc if et is exceptions.CancelledError else None | |
| if self._parent_cancel_requested: | |
| # If this flag is set we *must* call uncancel(). | |
| if self._parent_task.uncancel() == 0: | |
| # If there are no pending cancellations left, | |
| # don't propagate CancelledError. | |
| propagate_cancellation_error = None | |
| if et is not None: | |
| if not self._aborting: | |
| # Our parent task is being cancelled: | |
| # | |
| # async with TaskGroup() as g: | |
| # g.create_task(...) | |
| # await ... # <- CancelledError | |
| # | |
| # or there's an exception in "async with": | |
| # | |
| # async with TaskGroup() as g: | |
| # g.create_task(...) | |
| # 1 / 0 | |
| # | |
| self._abort() | |
| # We use while-loop here because "self._on_completed_fut" | |
| # can be cancelled multiple times if our parent task | |
| # is being cancelled repeatedly (or even once, when | |
| # our own cancellation is already in progress) | |
| while self._tasks: | |
| if self._on_completed_fut is None: | |
| self._on_completed_fut = self._loop.create_future() | |
| try: | |
| await self._on_completed_fut | |
| except exceptions.CancelledError as ex: | |
| if not self._aborting: | |
| # Our parent task is being cancelled: | |
| # | |
| # async def wrapper(): | |
| # async with TaskGroup() as g: | |
| # g.create_task(foo) | |
| # | |
| # "wrapper" is being cancelled while "foo" is | |
| # still running. | |
| propagate_cancellation_error = ex | |
| self._abort() | |
| self._on_completed_fut = None | |
| assert not self._tasks | |
| if self._base_error is not None: | |
| raise self._base_error | |
| # Propagate CancelledError if there is one, except if there | |
| # are other errors -- those have priority. | |
| if propagate_cancellation_error and not self._errors: | |
| raise propagate_cancellation_error | |
| if et is not None and et is not exceptions.CancelledError: | |
| self._errors.append(exc) | |
| if self._errors: | |
| # Exceptions are heavy objects that can have object | |
| # cycles (bad for GC); let's not keep a reference to | |
| # a bunch of them. | |
| try: | |
| me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors) | |
| raise me from None | |
| finally: | |
| self._errors = None | |
| def create_task(self, coro, *, name=None, context=None): | |
| """Create a new task in this group and return it. | |
| Similar to `asyncio.create_task`. | |
| """ | |
| if not self._entered: | |
| raise RuntimeError(f"TaskGroup {self!r} has not been entered") | |
| if self._exiting and not self._tasks: | |
| raise RuntimeError(f"TaskGroup {self!r} is finished") | |
| if self._aborting: | |
| raise RuntimeError(f"TaskGroup {self!r} is shutting down") | |
| if context is None: | |
| task = self._loop.create_task(coro) | |
| else: | |
| task = self._loop.create_task(coro, context=context) | |
| tasks._set_task_name(task, name) | |
| task.add_done_callback(self._on_task_done) | |
| self._tasks.add(task) | |
| return task | |
| # Since Python 3.8 Tasks propagate all exceptions correctly, | |
| # except for KeyboardInterrupt and SystemExit which are | |
| # still considered special. | |
| def _is_base_error(self, exc: BaseException) -> bool: | |
| assert isinstance(exc, BaseException) | |
| return isinstance(exc, (SystemExit, KeyboardInterrupt)) | |
| def _abort(self): | |
| self._aborting = True | |
| for t in self._tasks: | |
| if not t.done(): | |
| t.cancel() | |
| def _on_task_done(self, task): | |
| self._tasks.discard(task) | |
| if self._on_completed_fut is not None and not self._tasks: | |
| if not self._on_completed_fut.done(): | |
| self._on_completed_fut.set_result(True) | |
| if task.cancelled(): | |
| return | |
| exc = task.exception() | |
| if exc is None: | |
| return | |
| self._errors.append(exc) | |
| if self._is_base_error(exc) and self._base_error is None: | |
| self._base_error = exc | |
| if self._parent_task.done(): | |
| # Not sure if this case is possible, but we want to handle | |
| # it anyways. | |
| self._loop.call_exception_handler({ | |
| 'message': f'Task {task!r} has errored out but its parent ' | |
| f'task {self._parent_task} is already completed', | |
| 'exception': exc, | |
| 'task': task, | |
| }) | |
| return | |
| if not self._aborting and not self._parent_cancel_requested: | |
| # If parent task *is not* being cancelled, it means that we want | |
| # to manually cancel it to abort whatever is being run right now | |
| # in the TaskGroup. But we want to mark parent task as | |
| # "not cancelled" later in __aexit__. Example situation that | |
| # we need to handle: | |
| # | |
| # async def foo(): | |
| # try: | |
| # async with TaskGroup() as g: | |
| # g.create_task(crash_soon()) | |
| # await something # <- this needs to be canceled | |
| # # by the TaskGroup, e.g. | |
| # # foo() needs to be cancelled | |
| # except Exception: | |
| # # Ignore any exceptions raised in the TaskGroup | |
| # pass | |
| # await something_else # this line has to be called | |
| # # after TaskGroup is finished. | |
| self._abort() | |
| self._parent_cancel_requested = True | |
| self._parent_task.cancel() | |