Kris8an's picture
Upload folder using huggingface_hub
a06facb verified
# Copyright 2022 Amethyst Reese
# Licensed under the MIT license
import asyncio
from unittest import TestCase
import aioitertools as ait
import aioitertools.asyncio as aio
from .helpers import async_test
# Small sample fixtures: three single-letter strings and a three-element range.
slist = list("ABC")
srange = range(0, 3)
class AsyncioTest(TestCase):
    """Tests for the helpers exposed by ``aioitertools.asyncio``.

    Covers ``as_completed``, ``as_generated`` and ``gather``. Coroutine test
    methods are wrapped with ``async_test`` from the local ``helpers`` module.
    """

    def test_import(self):
        """The ``asyncio`` attribute of the package is the imported submodule."""
        self.assertEqual(ait.asyncio, aio)

    @async_test
    async def test_as_completed(self):
        """Results are expected in order of completion, not submission."""

        async def sleepy(number, duration):
            await asyncio.sleep(duration)
            return number

        # (number, sleep duration): the shortest sleep should finish first.
        pairs = [(1, 0.3), (2, 0.1), (3, 0.5)]
        expected = [2, 1, 3]

        # Consume through ait.list().
        futures = [sleepy(*pair) for pair in pairs]
        results = await ait.list(aio.as_completed(futures))
        self.assertEqual(results, expected)

        # Consume through a plain ``async for`` loop; coroutines are
        # single-use, so a fresh batch is created.
        futures = [sleepy(*pair) for pair in pairs]
        results = []
        async for value in aio.as_completed(futures):
            results.append(value)
        self.assertEqual(results, expected)

    @async_test
    async def test_as_completed_timeout(self):
        """A ``timeout`` shorter than the slowest awaitable raises TimeoutError."""
        # One sleep outlasts the 0.5s timeout, the other finishes well within it.
        calls = [(1.0,), (0.1,)]

        futures = [asyncio.sleep(*args) for args in calls]
        with self.assertRaises(asyncio.TimeoutError):
            await ait.list(aio.as_completed(futures, timeout=0.5))

        futures = [asyncio.sleep(*args) for args in calls]
        results = 0
        with self.assertRaises(asyncio.TimeoutError):
            async for _ in aio.as_completed(futures, timeout=0.5):
                results += 1
        # Only the 0.1s sleep is expected to be yielded before the timeout.
        self.assertEqual(results, 1)

    @async_test
    async def test_as_generated(self):
        """Every value from every generator is yielded, in any interleaving."""

        async def gen():
            for i in range(10):
                yield i
                await asyncio.sleep(0)

        gens = [gen(), gen(), gen()]
        expected = list(range(10)) * 3
        results = []
        async for value in aio.as_generated(gens):
            results.append(value)
        self.assertEqual(30, len(results))
        # Order across generators is not asserted, so compare sorted values.
        self.assertListEqual(sorted(expected), sorted(results))

    @async_test
    async def test_as_generated_exception(self):
        """An exception raised inside a generator propagates to the consumer."""

        async def gen1():
            for i in range(3):
                yield i
                await asyncio.sleep(0)
            raise Exception("fake")

        async def gen2():
            for i in range(10):
                yield i
                await asyncio.sleep(0)

        gens = [gen1(), gen2()]
        results = []
        with self.assertRaisesRegex(Exception, "fake"):
            async for value in aio.as_generated(gens):
                results.append(value)
        # NOTE(review): gen2 only yields 0-9, so 10 can never be in results and
        # this assertion is trivially true; consider a stronger check.
        self.assertNotIn(10, results)

    @async_test
    async def test_as_generated_return_exception(self):
        """With ``return_exceptions=True`` the error is yielded, not raised."""

        async def gen1():
            for i in range(3):
                yield i
                await asyncio.sleep(0)
            raise Exception("fake")

        async def gen2():
            for i in range(10):
                yield i
                await asyncio.sleep(0)

        gens = [gen1(), gen2()]
        expected = list(range(3)) + list(range(10))
        errors = []
        results = []
        async for value in aio.as_generated(gens, return_exceptions=True):
            if isinstance(value, Exception):
                errors.append(value)
            else:
                results.append(value)
        # All values from both generators plus exactly one exception object.
        self.assertListEqual(sorted(expected), sorted(results))
        self.assertEqual(1, len(errors))
        self.assertIsInstance(errors[0], Exception)

    @async_test
    async def test_as_generated_task_cancelled(self):
        """A generator raising CancelledError does not stop the other generators."""

        async def gen(limit: int = 10):
            # Raises CancelledError once ``i`` exceeds ``limit``; with the
            # default of 10 the range(5) loop always runs to completion.
            for i in range(5):
                if i > limit:
                    raise asyncio.CancelledError
                yield i
                await asyncio.sleep(0)

        gens = [gen(2), gen()]
        # gen(2) yields 0..2 before cancelling itself; gen() yields 0..4.
        expected = list(range(3)) + list(range(5))
        results = []
        async for value in aio.as_generated(gens):
            results.append(value)
        self.assertListEqual(sorted(expected), sorted(results))

    @async_test
    async def test_as_generated_cancelled(self):
        """Cancelling the consuming task ends iteration without raising."""

        async def gen():
            for i in range(5):
                yield i
                await asyncio.sleep(0.1)

        # After 0.15s each generator is expected to have produced 0 and 1.
        expected = [0, 0, 1, 1]
        results = []

        async def foo():
            gens = [gen(), gen()]
            async for value in aio.as_generated(gens):
                results.append(value)
            return results

        task = asyncio.ensure_future(foo())
        await asyncio.sleep(0.15)
        task.cancel()
        # The test expects this await to complete normally after cancel().
        await task

        self.assertListEqual(sorted(expected), sorted(results))

    @async_test
    async def test_gather_input_types(self):
        """``gather`` accepts a mix of coroutines, futures and tasks."""

        async def fn(arg):
            await asyncio.sleep(0.001)
            return arg

        fns = [fn(1), asyncio.ensure_future(fn(2))]
        if hasattr(asyncio, "create_task"):
            # 3.7+ only
            fns.append(asyncio.create_task(fn(3)))
        else:
            fns.append(fn(3))

        result = await aio.gather(*fns)
        self.assertEqual([1, 2, 3], result)

    @async_test
    async def test_gather_limited(self):
        """``limit`` caps how many awaitables run concurrently."""
        max_counter = 0
        counter = 0

        async def fn(arg):
            # Track the number of simultaneously running calls and its peak.
            nonlocal counter, max_counter
            counter += 1
            max_counter = max(max_counter, counter)
            await asyncio.sleep(0.001)
            counter -= 1
            return arg

        # Limit of 2
        result = await aio.gather(*[fn(i) for i in range(10)], limit=2)
        self.assertEqual(2, max_counter)
        self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], result)

        # No limit
        result = await aio.gather(*[fn(i) for i in range(10)])
        self.assertEqual(
            10, max_counter
        )  # TODO: on a loaded machine this might be less?
        self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], result)

    @async_test
    async def test_gather_limited_dupes(self):
        """The same awaitable passed several times fills each of its positions."""

        async def fn(arg):
            await asyncio.sleep(0.001)
            return arg

        # With a concurrency limit.
        f = fn(1)
        g = fn(2)
        result = await aio.gather(f, f, f, g, f, g, limit=2)
        self.assertEqual([1, 1, 1, 2, 1, 2], result)

        # Without a limit; fresh coroutines because the first pair was consumed.
        f = fn(1)
        g = fn(2)
        result = await aio.gather(f, f, f, g, f, g)
        self.assertEqual([1, 1, 1, 2, 1, 2], result)

    @async_test
    async def test_gather_with_exceptions(self):
        """Exceptions propagate by default and are returned on request."""

        class MyException(Exception):
            pass

        async def fn(arg, fail=False):
            await asyncio.sleep(arg)
            if fail:
                raise MyException(arg)
            return arg

        with self.assertRaises(MyException):
            await aio.gather(fn(0.002, fail=True), fn(0.001))

        result = await aio.gather(
            fn(0.002, fail=True), fn(0.001), return_exceptions=True
        )
        # Results keep argument order: the exception first, then the value.
        self.assertEqual(result[1], 0.001)
        self.assertIsInstance(result[0], MyException)

    @async_test
    async def test_gather_cancel(self):
        """Cancelling a task awaiting ``gather`` cancels the gathered coroutine."""
        cancelled = False
        started = False

        async def _fn():
            nonlocal started, cancelled
            try:
                started = True
                await asyncio.sleep(10)  # might as well be forever
            except asyncio.CancelledError:
                cancelled = True
                raise

        async def _gather():
            await aio.gather(_fn())

        if hasattr(asyncio, "create_task"):
            # 3.7+ only
            task = asyncio.create_task(_gather())
        else:
            task = asyncio.ensure_future(_gather())
        # to ensure the gather actually runs
        await asyncio.sleep(0)
        task.cancel()
        with self.assertRaises(asyncio.CancelledError):
            await task
        self.assertTrue(started)
        self.assertTrue(cancelled)