| import warnings |
| from itertools import chain |
|
|
| import pytest |
|
|
| from sklearn import config_context |
| from sklearn.utils._chunking import gen_even_slices, get_chunk_n_rows |
| from sklearn.utils._testing import assert_array_equal |
|
|
|
|
| def test_gen_even_slices(): |
| |
| some_range = range(10) |
| joined_range = list(chain(*[some_range[slice] for slice in gen_even_slices(10, 3)])) |
| assert_array_equal(some_range, joined_range) |
|
|
|
|
| @pytest.mark.parametrize( |
| ("row_bytes", "max_n_rows", "working_memory", "expected"), |
| [ |
| (1024, None, 1, 1024), |
| (1024, None, 0.99999999, 1023), |
| (1023, None, 1, 1025), |
| (1025, None, 1, 1023), |
| (1024, None, 2, 2048), |
| (1024, 7, 1, 7), |
| (1024 * 1024, None, 1, 1), |
| ], |
| ) |
| def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, expected): |
| with warnings.catch_warnings(): |
| warnings.simplefilter("error", UserWarning) |
| actual = get_chunk_n_rows( |
| row_bytes=row_bytes, |
| max_n_rows=max_n_rows, |
| working_memory=working_memory, |
| ) |
|
|
| assert actual == expected |
| assert type(actual) is type(expected) |
| with config_context(working_memory=working_memory): |
| with warnings.catch_warnings(): |
| warnings.simplefilter("error", UserWarning) |
| actual = get_chunk_n_rows(row_bytes=row_bytes, max_n_rows=max_n_rows) |
| assert actual == expected |
| assert type(actual) is type(expected) |
|
|
|
|
| def test_get_chunk_n_rows_warns(): |
| """Check that warning is raised when working_memory is too low.""" |
| row_bytes = 1024 * 1024 + 1 |
| max_n_rows = None |
| working_memory = 1 |
| expected = 1 |
|
|
| warn_msg = ( |
| "Could not adhere to working_memory config. Currently 1MiB, 2MiB required." |
| ) |
| with pytest.warns(UserWarning, match=warn_msg): |
| actual = get_chunk_n_rows( |
| row_bytes=row_bytes, |
| max_n_rows=max_n_rows, |
| working_memory=working_memory, |
| ) |
|
|
| assert actual == expected |
| assert type(actual) is type(expected) |
|
|
| with config_context(working_memory=working_memory): |
| with pytest.warns(UserWarning, match=warn_msg): |
| actual = get_chunk_n_rows(row_bytes=row_bytes, max_n_rows=max_n_rows) |
| assert actual == expected |
| assert type(actual) is type(expected) |
|
|