# tmp/pip-install-ghxuqwgs/numpy_78e94bf2b6094bf9a1f3d92042f9bf46/numpy/lib/tests/test_arrayterator.py
from __future__ import division, absolute_import, print_function

from operator import mul
from functools import reduce

import numpy as np
from numpy.random import randint
from numpy.lib import Arrayterator
from numpy.testing import assert_
def test():
    """Exercise ``Arrayterator`` iteration and slicing on a seeded random array.

    A fixed seed drives every random choice (number of dimensions, shape,
    buffer size and slice bounds), so the same case is checked on each run.
    The test verifies that:

    * no block yielded by the arrayterator exceeds the buffer size,
    * flat iteration yields the same elements as the wrapped array,
    * slicing the arrayterator matches slicing the wrapped array, both as a
      whole array and element by element.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    # Use a private generator, seeded with the same value as before, so the
    # test stays reproducible without reseeding the process-wide
    # ``np.random`` state that other tests may depend on.
    rng = np.random.RandomState(np.arange(10))

    # Create a random array
    ndims = rng.randint(5) + 1
    shape = tuple(rng.randint(10) + 1 for _ in range(ndims))
    els = reduce(mul, shape)
    a = np.arange(els).reshape(shape)
    buf_size = rng.randint(2 * els)
    b = Arrayterator(a, buf_size)

    # ``buf_size`` can be drawn as 0; the bound then falls back to the total
    # number of elements.
    max_block = buf_size or els

    # Check that each block has at most ``buf_size`` elements
    for block in b:
        assert_(block.size <= max_block)

    # Check that all elements are iterated correctly
    assert_(list(b.flat) == list(a.flat))

    # Slice arrayterator
    start = [rng.randint(dim) for dim in shape]
    stop = [rng.randint(dim) + 1 for dim in shape]
    step = [rng.randint(dim) + 1 for dim in shape]
    slice_ = tuple(slice(*t) for t in zip(start, stop, step))
    c = b[slice_]
    d = a[slice_]

    # Check that each block has at most ``buf_size`` elements
    for block in c:
        assert_(block.size <= max_block)

    # Check that the arrayterator is sliced correctly; ``array_equal`` also
    # compares shapes, so a mismatch cannot slip through via broadcasting.
    assert_(np.array_equal(c.__array__(), d))

    # Check that all elements are iterated correctly
    assert_(list(c.flat) == list(d.flat))
if __name__ == '__main__':
    # ``run_module_suite`` belongs to the legacy nose-based runner, which
    # newer NumPy releases no longer provide; fall back to calling the test
    # directly so the file stays runnable as a script.
    try:
        from numpy.testing import run_module_suite
    except ImportError:
        test()
    else:
        run_module_suite()