Kris8an's picture
Upload folder using huggingface_hub
a06facb verified
# Copyright 2022 Amethyst Reese
# Licensed under the MIT license
import asyncio
import operator
from unittest import TestCase
import aioitertools as ait
from .helpers import async_test
# Shared fixtures reused across tests: three one-letter strings and the
# integers 1, 2, 3.
slist = list("ABC")
srange = range(1, 4)
class ItertoolsTest(TestCase):
    """Tests for the async ports of ``itertools`` provided by aioitertools.

    Each test drives an ``ait`` iterator step by step with ``ait.next`` and,
    for finite inputs, confirms the iterator is exhausted afterwards by
    expecting ``StopAsyncIteration``. Most functions are exercised with both
    plain iterables and async generators, and -- where they take a callback --
    with both regular functions and coroutine functions.
    """

    @async_test
    async def test_accumulate_range_default(self):
        it = ait.accumulate(srange)
        for k in [1, 3, 6]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_range_function(self):
        it = ait.accumulate(srange, func=operator.mul)
        for k in [1, 2, 6]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_range_coroutine(self):
        async def mul(a, b):
            return a * b

        it = ait.accumulate(srange, func=mul)
        for k in [1, 2, 6]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_gen_function(self):
        async def gen():
            yield 1
            yield 2
            yield 4

        it = ait.accumulate(gen(), func=operator.mul)
        for k in [1, 2, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_gen_coroutine(self):
        async def mul(a, b):
            return a * b

        async def gen():
            yield 1
            yield 2
            yield 4

        it = ait.accumulate(gen(), func=mul)
        for k in [1, 2, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_accumulate_empty(self):
        values = []
        async for value in ait.accumulate([]):
            values.append(value)
        self.assertEqual(values, [])

    @async_test
    async def test_chain_lists(self):
        it = ait.chain(slist, srange)
        for k in ["A", "B", "C", 1, 2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_chain_list_gens(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        it = ait.chain(slist, gen())
        for k in ["A", "B", "C", 2, 4, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_chain_from_iterable(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        it = ait.chain.from_iterable([slist, gen()])
        for k in ["A", "B", "C", 2, 4, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_chain_from_iterable_parameter_expansion_gen(self):
        async def gen():
            for k in range(2, 9, 2):
                yield k

        # The outer iterable is itself an async generator of iterables.
        async def parameters_gen():
            yield slist
            yield gen()

        it = ait.chain.from_iterable(parameters_gen())
        for k in ["A", "B", "C", 2, 4, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_combinations(self):
        it = ait.combinations(range(4), 3)
        for k in [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_combinations_with_replacement(self):
        it = ait.combinations_with_replacement(slist, 2)
        for k in [
            ("A", "A"),
            ("A", "B"),
            ("A", "C"),
            ("B", "B"),
            ("B", "C"),
            ("C", "C"),
        ]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_compress_list(self):
        data = range(10)
        selectors = [0, 1, 1, 0, 0, 0, 1, 0, 1, 0]
        it = ait.compress(data, selectors)
        for k in [1, 2, 6, 8]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_compress_gen(self):
        data = "abcdefghijkl"
        # The selectors come from an infinite cycle, so exhaustion below must
        # be driven by the finite data string running out.
        selectors = ait.cycle([1, 0, 0])
        it = ait.compress(data, selectors)
        for k in ["a", "d", "g", "j"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    # count() never terminates, so the count tests only check a prefix.
    @async_test
    async def test_count_bare(self):
        it = ait.count()
        for k in [0, 1, 2, 3]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_count_start(self):
        it = ait.count(42)
        for k in [42, 43, 44, 45]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_count_start_step(self):
        it = ait.count(42, 3)
        for k in [42, 45, 48, 51]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_count_negative(self):
        it = ait.count(step=-2)
        for k in [0, -2, -4, -6]:
            self.assertEqual(await ait.next(it), k)

    # cycle() never terminates; these tests check that it wraps around.
    @async_test
    async def test_cycle_list(self):
        it = ait.cycle(slist)
        for k in ["A", "B", "C", "A", "B", "C", "A", "B"]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_cycle_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 42

        # A one-shot async generator must still be replayable by cycle().
        it = ait.cycle(gen())
        for k in [1, 2, 42, 1, 2, 42, 1, 2]:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_dropwhile_empty(self):
        def pred(x):
            return x < 2

        result = await ait.list(ait.dropwhile(pred, []))
        self.assertEqual(result, [])

    @async_test
    async def test_dropwhile_function_list(self):
        def pred(x):
            return x < 2

        it = ait.dropwhile(pred, srange)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_dropwhile_function_gen(self):
        def pred(x):
            return x < 2

        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.dropwhile(pred, gen())
        for k in [2, 42]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_dropwhile_coroutine_list(self):
        async def pred(x):
            return x < 2

        it = ait.dropwhile(pred, srange)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_dropwhile_coroutine_gen(self):
        async def pred(x):
            return x < 2

        async def gen():
            yield 1
            yield 2
            yield 42

        it = ait.dropwhile(pred, gen())
        for k in [2, 42]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_filterfalse_function_list(self):
        def pred(x):
            return x % 2 == 0

        it = ait.filterfalse(pred, srange)
        for k in [1, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_filterfalse_coroutine_list(self):
        async def pred(x):
            return x % 2 == 0

        it = ait.filterfalse(pred, srange)
        for k in [1, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_list(self):
        data = "aaabba"
        it = ait.groupby(data)
        # Groups are consecutive runs, so "a" appears twice.
        for k in [("a", ["a", "a", "a"]), ("b", ["b", "b"]), ("a", ["a"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_list_key(self):
        data = "aAabBA"
        it = ait.groupby(data, key=str.lower)
        for k in [("a", ["a", "A", "a"]), ("b", ["b", "B"]), ("a", ["A"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_gen(self):
        async def gen():
            for c in "aaabba":
                yield c

        it = ait.groupby(gen())
        for k in [("a", ["a", "a", "a"]), ("b", ["b", "b"]), ("a", ["a"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_gen_key(self):
        async def gen():
            for c in "aAabBA":
                yield c

        it = ait.groupby(gen(), key=str.lower)
        for k in [("a", ["a", "A", "a"]), ("b", ["b", "B"]), ("a", ["A"])]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_groupby_empty(self):
        async def gen():
            for _ in range(0):
                yield  # Force generator with no actual iteration

        async for _ in ait.groupby(gen()):
            self.fail("No iteration should have happened")

    @async_test
    async def test_islice_bad_range(self):
        with self.assertRaisesRegex(ValueError, "must pass stop index"):
            async for _ in ait.islice([1, 2]):
                pass
        with self.assertRaisesRegex(ValueError, "too many arguments"):
            async for _ in ait.islice([1, 2], 1, 2, 3, 4):
                pass

    @async_test
    async def test_islice_stop_zero(self):
        values = []
        async for value in ait.islice(range(5), 0):
            values.append(value)
        self.assertEqual(values, [])

    @async_test
    async def test_islice_range_stop(self):
        it = ait.islice(srange, 2)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_range_start_step(self):
        it = ait.islice(srange, 0, None, 2)
        for k in [1, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_range_start_stop(self):
        it = ait.islice(srange, 1, 3)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_range_start_stop_step(self):
        it = ait.islice(srange, 1, 3, 2)
        for k in [2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_gen_stop(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        gen_it = gen()
        it = ait.islice(gen_it, 2)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)
        # islice must not pull more items than it needs from the source: the
        # remainder stays available. Use assertEqual (not a bare assert) so the
        # check survives `python -O` and reports a diff on failure.
        self.assertEqual(await ait.list(gen_it), [3, 4])

    @async_test
    async def test_islice_gen_start_step(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        it = ait.islice(gen(), 1, None, 2)
        for k in [2, 4]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_gen_start_stop(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        it = ait.islice(gen(), 1, 3)
        for k in [2, 3]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_islice_gen_start_stop_step(self):
        async def gen():
            yield 1
            yield 2
            yield 3
            yield 4

        gen_it = gen()
        it = ait.islice(gen_it, 1, 3, 2)
        for k in [2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)
        # Items up to the stop index are consumed; anything after is left in
        # the source. assertEqual keeps this check active under `python -O`.
        self.assertEqual(await ait.list(gen_it), [4])

    @async_test
    async def test_permutations_list(self):
        it = ait.permutations(srange, r=2)
        for k in [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_permutations_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 3

        it = ait.permutations(gen(), r=2)
        for k in [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_product_list(self):
        it = ait.product([1, 2], [6, 7])
        for k in [(1, 6), (1, 7), (2, 6), (2, 7)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_product_gen(self):
        async def gen(x):
            yield x
            yield x + 1

        it = ait.product(gen(1), gen(6))
        for k in [(1, 6), (1, 7), (2, 6), (2, 7)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_repeat(self):
        # Without a limit repeat() is unbounded, so only a prefix is checked.
        it = ait.repeat(42)
        for k in [42] * 10:
            self.assertEqual(await ait.next(it), k)

    @async_test
    async def test_repeat_limit(self):
        it = ait.repeat(42, 5)
        for k in [42] * 5:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_function_list(self):
        data = [slist[:2], slist[1:], slist]

        def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, data)
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_function_gen(self):
        # Deliberately a plain (synchronous) generator.
        def gen():
            yield slist[:2]
            yield slist[1:]
            yield slist

        def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, gen())
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_coroutine_list(self):
        data = [slist[:2], slist[1:], slist]

        async def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, data)
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_starmap_coroutine_gen(self):
        async def gen():
            yield slist[:2]
            yield slist[1:]
            yield slist

        async def concat(*args):
            return "".join(args)

        it = ait.starmap(concat, gen())
        for k in ["AB", "BC", "ABC"]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_empty(self):
        def pred(x):
            return x < 3

        values = await ait.list(ait.takewhile(pred, []))
        self.assertEqual(values, [])

    @async_test
    async def test_takewhile_function_list(self):
        def pred(x):
            return x < 3

        it = ait.takewhile(pred, srange)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_function_gen(self):
        async def gen():
            yield 1
            yield 2
            yield 3

        def pred(x):
            return x < 3

        it = ait.takewhile(pred, gen())
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_coroutine_list(self):
        async def pred(x):
            return x < 3

        it = ait.takewhile(pred, srange)
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_takewhile_coroutine_gen(self):
        # Deliberately a plain (synchronous) generator.
        def gen():
            yield 1
            yield 2
            yield 3

        async def pred(x):
            return x < 3

        it = ait.takewhile(pred, gen())
        for k in [1, 2]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_tee_list_two(self):
        it1, it2 = ait.tee(slist * 2)
        for k in slist * 2:
            # Advance both copies concurrently; each must see the same value.
            a, b = await asyncio.gather(ait.next(it1), ait.next(it2))
            self.assertEqual(a, b)
            self.assertEqual(a, k)
            self.assertEqual(b, k)
        for it in [it1, it2]:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_list_six(self):
        itrs = ait.tee(slist * 2, n=6)
        for k in slist * 2:
            values = await asyncio.gather(*[ait.next(it) for it in itrs])
            for value in values:
                self.assertEqual(value, k)
        for it in itrs:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_gen_two(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        it1, it2 = ait.tee(gen())
        for k in [1, 4, 9, 16]:
            a, b = await asyncio.gather(ait.next(it1), ait.next(it2))
            self.assertEqual(a, b)
            self.assertEqual(a, k)
            self.assertEqual(b, k)
        for it in [it1, it2]:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_gen_six(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        itrs = ait.tee(gen(), n=6)
        for k in [1, 4, 9, 16]:
            values = await asyncio.gather(*[ait.next(it) for it in itrs])
            for value in values:
                self.assertEqual(value, k)
        for it in itrs:
            with self.assertRaises(StopAsyncIteration):
                await ait.next(it)

    @async_test
    async def test_tee_propagate_exception(self):
        class MyError(Exception):
            pass

        async def gen():
            yield 1
            yield 2
            raise MyError

        async def consumer(it):
            result = 0
            async for item in it:
                result += item
            return result

        it1, it2 = ait.tee(gen())
        # return_exceptions=True makes gather hand back raised exceptions as
        # results, so we can check that *every* tee copy saw the error.
        values = await asyncio.gather(
            consumer(it1),
            consumer(it2),
            return_exceptions=True,
        )
        for value in values:
            self.assertIsInstance(value, MyError)

    @async_test
    async def test_zip_longest_range(self):
        a = range(3)
        b = range(5)
        it = ait.zip_longest(a, b)
        for k in [(0, 0), (1, 1), (2, 2), (None, 3), (None, 4)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_zip_longest_fillvalue(self):
        async def gen():
            yield 1
            yield 4
            yield 9
            yield 16

        a = gen()
        b = range(5)
        it = ait.zip_longest(a, b, fillvalue=42)
        for k in [(1, 0), (4, 1), (9, 2), (16, 3), (42, 4)]:
            self.assertEqual(await ait.next(it), k)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_zip_longest_exception(self):
        async def gen():
            yield 1
            yield 2
            raise Exception("fake error")

        a = gen()
        b = ait.repeat(5)
        it = ait.zip_longest(a, b)
        for k in [(1, 5), (2, 5)]:
            self.assertEqual(await ait.next(it), k)
        # An error from one input must surface rather than be treated as
        # that input running out.
        with self.assertRaisesRegex(Exception, "fake error"):
            await ait.next(it)