# Copyright 2022 Amethyst Reese
# Licensed under the MIT license
"""
Friendlier version of asyncio standard library.
Provisional library. Must be imported as `aioitertools.asyncio`.
"""
import asyncio
import time
from typing import (
Any,
AsyncGenerator,
AsyncIterable,
Awaitable,
cast,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)
from .builtins import iter as aiter, maybe_await
from .types import AnyIterable, AsyncIterator, MaybeAwaitable, T
async def as_completed(
    aws: Iterable[Awaitable[T]],
    *,
    timeout: Optional[float] = None,
) -> AsyncIterator[T]:
    """
    Run awaitables in `aws` concurrently, and yield results as they complete.

    Unlike `asyncio.as_completed`, this yields actual results, and does not require
    awaiting each item in the iterable.

    Cancels all remaining awaitables if a timeout is given and the timeout threshold
    is reached.

    :param aws: iterable of awaitables to run concurrently
    :param timeout: optional number of seconds before all remaining awaitables
        are cancelled and :class:`asyncio.TimeoutError` is raised; falsy or
        negative values mean "no timeout"

    Example::

        async for value in as_completed(futures):
            ...  # use value immediately
    """
    # ensure_future() schedules each awaitable as a Task immediately, and
    # asyncio.wait() hands back the very same Future objects we pass in, so
    # no casting between Awaitable and Future is needed.
    pending = {asyncio.ensure_future(a) for a in aws}
    remaining: Optional[float] = None
    if timeout and timeout > 0:
        # Use the monotonic clock for the deadline: unlike time.time(), it is
        # unaffected by wall-clock adjustments (NTP steps, DST, manual changes).
        deadline = time.monotonic() + timeout
    else:
        # Normalize falsy/negative timeouts to "no timeout".
        timeout = None

    while pending:
        if timeout:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # Deadline reached: cancel everything still outstanding and
                # signal the timeout to the caller. Every member of `pending`
                # is a Future (see ensure_future above), so cancel() is safe.
                for fut in pending:
                    fut.cancel()
                raise asyncio.TimeoutError()

        # Wait until at least one future finishes or the deadline expires;
        # asyncio.wait() never raises on timeout, it simply returns early
        # with whatever completed (possibly nothing).
        done, pending = await asyncio.wait(
            pending,
            timeout=remaining,
            return_when=asyncio.FIRST_COMPLETED,
        )
        for item in done:
            # Futures in `done` have completed; awaiting them only retrieves
            # the result (or re-raises the stored exception).
            yield await item
async def as_generated(
    iterables: Iterable[AsyncIterable[T]],
    *,
    return_exceptions: bool = False,
) -> AsyncIterable[T]:
    """
    Yield results from one or more async iterables, in the order they are produced.

    Like :func:`as_completed`, but for async iterators or generators instead of futures.
    Creates a separate task to drain each iterable, and a single queue for results.

    If ``return_exceptions`` is ``False``, then any exception will be raised, and
    pending iterables and tasks will be cancelled, and async generators will be closed.

    If ``return_exceptions`` is ``True``, any exceptions will be yielded as results,
    and execution will continue until all iterables have been fully consumed.

    Example::

        async def generator(x):
            for i in range(x):
                yield i

        gen1 = generator(10)
        gen2 = generator(12)

        async for value in as_generated([gen1, gen2]):
            ...  # intermixed values yielded from gen1 and gen2
    """
    # Exceptions raised inside tailer tasks are funneled through this queue
    # rather than propagating in the task itself, so the consumer loop below
    # decides whether to yield them (return_exceptions=True) or re-raise.
    exc_queue: asyncio.Queue[Exception] = asyncio.Queue()
    # Single shared queue that every tailer task feeds its items into.
    queue: asyncio.Queue[T] = asyncio.Queue()

    async def tailer(iter: AsyncIterable[T]) -> None:
        # Drain one iterable into the shared queue. NOTE(review): `iter`
        # shadows the builtin of the same name within this helper.
        try:
            async for item in iter:
                await queue.put(item)
        except asyncio.CancelledError:
            # Explicitly close async generators on cancellation so their
            # cleanup (finally blocks) runs promptly, then re-raise so the
            # task is marked as cancelled.
            if isinstance(iter, AsyncGenerator):  # pragma:nocover
                await iter.aclose()
            raise
        except Exception as e:
            # Defer the failure to the consumer loop via exc_queue.
            await exc_queue.put(e)

    # One task per input iterable; `pending` tracks the ones still running.
    tasks = [asyncio.ensure_future(tailer(iter)) for iter in iterables]
    pending = set(tasks)

    try:
        while pending:
            # First surface any deferred exception from a tailer.
            try:
                exc = exc_queue.get_nowait()
                if return_exceptions:
                    yield exc  # type: ignore
                else:
                    raise exc
            except asyncio.QueueEmpty:
                pass
            # Then yield at most one queued value per pass through the loop.
            try:
                value = queue.get_nowait()
                yield value
            except asyncio.QueueEmpty:
                # Nothing ready yet: prune finished tailers, then briefly
                # yield control so tailer tasks can make progress (poll loop
                # with a 1 ms back-off).
                for task in list(pending):
                    if task.done():
                        pending.remove(task)
                await asyncio.sleep(0.001)
    except (asyncio.CancelledError, GeneratorExit):
        # Consumer cancelled us or stopped iterating early; fall through to
        # the cleanup in finally.
        pass
    finally:
        # Cancel and await every tailer so no task outlives this generator.
        for task in tasks:
            if not task.done():
                task.cancel()
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
async def gather(
    *args: Awaitable[T],
    return_exceptions: bool = False,
    limit: int = -1,
) -> List[Any]:
    """
    Like asyncio.gather but with a limit on concurrency.

    Note that all results are buffered.

    If gather is cancelled all tasks that were internally created and still pending
    will be cancelled as well.

    :param return_exceptions: if true, store raised exceptions in the result
        list instead of propagating them
    :param limit: maximum number of awaitables running at once; -1 = unlimited

    Example::

        futures = [some_coro(i) for i in range(10)]
        results = await gather(*futures, limit=2)
    """
    # For detecting input duplicates and reconciling them at the end:
    # maps each input awaitable to every position it occupies in args.
    input_map: Dict[Awaitable[T], List[int]] = {}
    # This is keyed on what we'll get back from asyncio.wait
    pos: Dict[asyncio.Future[T], int] = {}
    # Results buffered in input order; slots are filled as tasks complete.
    ret: List[Any] = [None] * len(args)
    pending: Set[asyncio.Future[T]] = set()
    done: Set[asyncio.Future[T]] = set()
    # Index of the next argument to be scheduled.
    next_arg = 0

    while True:
        # Top up the pending set until the concurrency limit is hit or all
        # arguments have been scheduled.
        while next_arg < len(args) and (limit == -1 or len(pending) < limit):
            # We have to defer the creation of the Task as long as possible
            # because once we do, it starts executing, regardless of what we
            # have in the pending set.
            if args[next_arg] in input_map:
                # Duplicate input: don't schedule it twice, just remember the
                # extra position so the result can be copied there at the end.
                input_map[args[next_arg]].append(next_arg)
            else:
                # We call ensure_future directly to ensure that we have a Task
                # because the return value of asyncio.wait will be an implicit
                # task otherwise, and we won't be able to know which input it
                # corresponds to.
                task: asyncio.Future[T] = asyncio.ensure_future(args[next_arg])
                pending.add(task)
                pos[task] = next_arg
                input_map[args[next_arg]] = [next_arg]
            next_arg += 1

        # pending might be empty if the last items of args were dupes;
        # asyncio.wait([]) will raise an exception.
        if pending:
            try:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                for x in done:
                    if return_exceptions and x.exception():
                        ret[pos[x]] = x.exception()
                    else:
                        # result() re-raises the task's exception here when
                        # return_exceptions is False.
                        ret[pos[x]] = x.result()
            except asyncio.CancelledError:
                # Since we created these tasks we should cancel them
                for x in pending:
                    x.cancel()
                # we ensure that all tasks are cancelled before we raise
                await asyncio.gather(*pending, return_exceptions=True)
                raise

        if not pending and next_arg == len(args):
            break

    # Copy each unique result into the slots of any duplicate inputs.
    for lst in input_map.values():
        for i in range(1, len(lst)):
            ret[lst[i]] = ret[lst[0]]

    return ret
async def gather_iter(
    itr: AnyIterable[MaybeAwaitable[T]],
    return_exceptions: bool = False,
    limit: int = -1,
) -> List[T]:
    """
    Wrapper around gather to handle gathering an iterable instead of *args.

    Note that the iterable values don't have to be awaitable.
    """
    # Collect the (possibly plain) values up front, wrapping each one so that
    # non-awaitable items become awaitables that resolve to themselves.
    pending: List[Awaitable[T]] = []
    async for item in aiter(itr):
        pending.append(maybe_await(item))
    return await gather(
        *pending,
        return_exceptions=return_exceptions,
        limit=limit,
    )