| | import atexit |
| | import os |
| | import warnings |
| |
|
| | import numpy as np |
| | import pytest |
| | from scipy import sparse |
| |
|
| | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis |
| | from sklearn.tree import DecisionTreeClassifier |
| | from sklearn.utils._testing import ( |
| | TempMemmap, |
| | _convert_container, |
| | _delete_folder, |
| | _get_warnings_filters_info_list, |
| | assert_allclose, |
| | assert_allclose_dense_sparse, |
| | assert_docstring_consistency, |
| | assert_run_python_script_without_output, |
| | check_docstring_parameters, |
| | create_memmap_backed_data, |
| | ignore_warnings, |
| | raises, |
| | set_random_state, |
| | skip_if_no_numpydoc, |
| | turn_warnings_into_errors, |
| | ) |
| | from sklearn.utils.deprecation import deprecated |
| | from sklearn.utils.fixes import ( |
| | _IS_WASM, |
| | CSC_CONTAINERS, |
| | CSR_CONTAINERS, |
| | parse_version, |
| | sp_version, |
| | ) |
| | from sklearn.utils.metaestimators import available_if |
| |
|
| |
|
| | def test_set_random_state(): |
| | lda = LinearDiscriminantAnalysis() |
| | tree = DecisionTreeClassifier() |
| | |
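| | # LinearDiscriminantAnalysis has no random_state parameter; this call is only a smoke test. |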
| | set_random_state(lda, 3) |
| | set_random_state(tree, 3) |
| | assert tree.random_state == 3 |
| |
|
| |
|
| | @pytest.mark.parametrize("csr_container", CSC_CONTAINERS) |
| | def test_assert_allclose_dense_sparse(csr_container): |
| | x = np.arange(9).reshape(3, 3) |
| | msg = "Not equal to tolerance " |
| | y = csr_container(x) |
| | for X in [x, y]: |
| | |
| | with pytest.raises(AssertionError, match=msg): |
| | assert_allclose_dense_sparse(X, X * 2) |
| | assert_allclose_dense_sparse(X, X) |
| |
|
| | with pytest.raises(ValueError, match="Can only compare two sparse"): |
| | assert_allclose_dense_sparse(x, y) |
| |
|
| | A = sparse.diags(np.ones(5), offsets=0).tocsr() |
| | B = csr_container(np.ones((1, 5))) |
| | with pytest.raises(AssertionError, match="Arrays are not equal"): |
| | assert_allclose_dense_sparse(B, A) |
| |
|
| |
|
| | def test_ignore_warning(): |
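| | # Check that ignore_warnings works as a plain call, as a decorator and as a context manager. |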
| | |
| | |
| | def _warning_function(): |
| | warnings.warn("deprecation warning", DeprecationWarning) |
| |
|
| | def _multiple_warning_function(): |
| | warnings.warn("deprecation warning", DeprecationWarning) |
| | warnings.warn("deprecation warning") |
| |
|
| | |
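| | # Check ignore_warnings applied directly to the warning-raising functions. |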
| | with warnings.catch_warnings(): |
| | warnings.simplefilter("error") |
| |
|
| | ignore_warnings(_warning_function) |
| | ignore_warnings(_warning_function, category=DeprecationWarning) |
| |
|
| | with pytest.warns(DeprecationWarning): |
| | ignore_warnings(_warning_function, category=UserWarning)() |
| |
|
| | with pytest.warns() as record: |
| | ignore_warnings(_multiple_warning_function, category=FutureWarning)() |
| | assert len(record) == 2 |
| | assert isinstance(record[0].message, DeprecationWarning) |
| | assert isinstance(record[1].message, UserWarning) |
| |
|
| | with pytest.warns() as record: |
| | ignore_warnings(_multiple_warning_function, category=UserWarning)() |
| | assert len(record) == 1 |
| | assert isinstance(record[0].message, DeprecationWarning) |
| |
|
| | with warnings.catch_warnings(): |
| | warnings.simplefilter("error") |
| |
|
| | ignore_warnings(_warning_function, category=(DeprecationWarning, UserWarning)) |
| |
|
| | |
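| | # Check the decorator form of ignore_warnings. |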
| | @ignore_warnings |
| | def decorator_no_warning(): |
| | _warning_function() |
| | _multiple_warning_function() |
| |
|
| | @ignore_warnings(category=(DeprecationWarning, UserWarning)) |
| | def decorator_no_warning_multiple(): |
| | _multiple_warning_function() |
| |
|
| | @ignore_warnings(category=DeprecationWarning) |
| | def decorator_no_deprecation_warning(): |
| | _warning_function() |
| |
|
| | @ignore_warnings(category=UserWarning) |
| | def decorator_no_user_warning(): |
| | _warning_function() |
| |
|
| | @ignore_warnings(category=DeprecationWarning) |
| | def decorator_no_deprecation_multiple_warning(): |
| | _multiple_warning_function() |
| |
|
| | @ignore_warnings(category=UserWarning) |
| | def decorator_no_user_multiple_warning(): |
| | _multiple_warning_function() |
| |
|
| | with warnings.catch_warnings(): |
| | warnings.simplefilter("error") |
| |
|
| | decorator_no_warning() |
| | decorator_no_warning_multiple() |
| | decorator_no_deprecation_warning() |
| |
|
| | with pytest.warns(DeprecationWarning): |
| | decorator_no_user_warning() |
| | with pytest.warns(UserWarning): |
| | decorator_no_deprecation_multiple_warning() |
| | with pytest.warns(DeprecationWarning): |
| | decorator_no_user_multiple_warning() |
| |
|
| | |
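| | # Check the context-manager form of ignore_warnings. |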
| | def context_manager_no_warning(): |
| | with ignore_warnings(): |
| | _warning_function() |
| |
|
| | def context_manager_no_warning_multiple(): |
| | with ignore_warnings(category=(DeprecationWarning, UserWarning)): |
| | _multiple_warning_function() |
| |
|
| | def context_manager_no_deprecation_warning(): |
| | with ignore_warnings(category=DeprecationWarning): |
| | _warning_function() |
| |
|
| | def context_manager_no_user_warning(): |
| | with ignore_warnings(category=UserWarning): |
| | _warning_function() |
| |
|
| | def context_manager_no_deprecation_multiple_warning(): |
| | with ignore_warnings(category=DeprecationWarning): |
| | _multiple_warning_function() |
| |
|
| | def context_manager_no_user_multiple_warning(): |
| | with ignore_warnings(category=UserWarning): |
| | _multiple_warning_function() |
| |
|
| | with warnings.catch_warnings(): |
| | warnings.simplefilter("error") |
| |
|
| | context_manager_no_warning() |
| | context_manager_no_warning_multiple() |
| | context_manager_no_deprecation_warning() |
| |
|
| | with pytest.warns(DeprecationWarning): |
| | context_manager_no_user_warning() |
| | with pytest.warns(UserWarning): |
| | context_manager_no_deprecation_multiple_warning() |
| | with pytest.warns(DeprecationWarning): |
| | context_manager_no_user_multiple_warning() |
| |
|
| | |
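| | # Passing a warning class positionally instead of a callable should raise a ValueError. |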
| | warning_class = UserWarning |
| | match = "'obj' should be a callable.+you should use 'category=UserWarning'" |
| |
|
| | with pytest.raises(ValueError, match=match): |
| | silence_warnings_func = ignore_warnings(warning_class)(_warning_function) |
| | silence_warnings_func() |
| |
|
| | with pytest.raises(ValueError, match=match): |
| |
|
| | @ignore_warnings(warning_class) |
| | def test(): |
| | pass |
| |
|
| |
|
| | |
| |
|
| |
|
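| | # Fixtures with correct and deliberately broken docstrings for check_docstring_parameters. |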
| | def f_ok(a, b): |
| | """Function f |
| | |
| | Parameters |
| | ---------- |
| | a : int |
| | Parameter a |
| | b : float |
| | Parameter b |
| | |
| | Returns |
| | ------- |
| | c : list |
| | Parameter c |
| | """ |
| | c = a + b |
| | return c |
| |
|
| |
|
| | def f_bad_sections(a, b): |
| | """Function f |
| | |
| | Parameters |
| | ---------- |
| | a : int |
| | Parameter a |
| | b : float |
| | Parameter b |
| | |
| | Results |
| | ------- |
| | c : list |
| | Parameter c |
| | """ |
| | c = a + b |
| | return c |
| |
|
| |
|
| | def f_bad_order(b, a): |
| | """Function f |
| | |
| | Parameters |
| | ---------- |
| | a : int |
| | Parameter a |
| | b : float |
| | Parameter b |
| | |
| | Returns |
| | ------- |
| | c : list |
| | Parameter c |
| | """ |
| | c = a + b |
| | return c |
| |
|
| |
|
| | def f_too_many_param_docstring(a, b): |
| | """Function f |
| | |
| | Parameters |
| | ---------- |
| | a : int |
| | Parameter a |
| | b : int |
| | Parameter b |
| | c : int |
| | Parameter c |
| | |
| | Returns |
| | ------- |
| | d : list |
| | Parameter c |
| | """ |
| | d = a + b |
| | return d |
| |
|
| |
|
| | def f_missing(a, b): |
| | """Function f |
| | |
| | Parameters |
| | ---------- |
| | a : int |
| | Parameter a |
| | |
| | Returns |
| | ------- |
| | c : list |
| | Parameter c |
| | """ |
| | c = a + b |
| | return c |
| |
|
| |
|
| | def f_check_param_definition(a, b, c, d, e): |
| | """Function f |
| | |
| | Parameters |
| | ---------- |
| | a: int |
| | Parameter a |
| | b: |
| | Parameter b |
| | c : |
| | This is parsed correctly in numpydoc 1.2 |
| | d:int |
| | Parameter d |
| | e |
| | No typespec is allowed without colon |
| | """ |
| | return a + b + c + d |
| |
|
| |
|
| | class Klass: |
| | def f_missing(self, X, y): |
| | pass |
| |
|
| | def f_bad_sections(self, X, y): |
| | """Function f |
| | |
| | Parameter |
| | --------- |
| | a : int |
| | Parameter a |
| | b : float |
| | Parameter b |
| | |
| | Results |
| | ------- |
| | c : list |
| | Parameter c |
| | """ |
| | pass |
| |
|
| |
|
| | class MockEst: |
| | def __init__(self): |
| | """MockEstimator""" |
| |
|
| | def fit(self, X, y): |
| | return X |
| |
|
| | def predict(self, X): |
| | return X |
| |
|
| | def predict_proba(self, X): |
| | return X |
| |
|
| | def score(self, X): |
| | return 1.0 |
| |
|
| |
|
| | class MockMetaEstimator: |
| | def __init__(self, delegate): |
| | """MetaEstimator to check if doctest on delegated methods work. |
| | |
| | Parameters |
| | --------- |
| | delegate : estimator |
| | Delegated estimator. |
| | """ |
| | self.delegate = delegate |
| |
|
| | @available_if(lambda self: hasattr(self.delegate, "predict")) |
| | def predict(self, X): |
| | """This is available only if delegate has predict. |
| | |
| | Parameters |
| | ---------- |
| | y : ndarray |
| | Parameter y |
| | """ |
| | return self.delegate.predict(X) |
| |
|
| | @available_if(lambda self: hasattr(self.delegate, "score")) |
| | @deprecated("Testing a deprecated delegated method") |
| | def score(self, X): |
| | """This is available only if delegate has score. |
| | |
| | Parameters |
| | --------- |
| | y : ndarray |
| | Parameter y |
| | """ |
| |
|
| | @available_if(lambda self: hasattr(self.delegate, "predict_proba")) |
| | def predict_proba(self, X): |
| | """This is available only if delegate has predict_proba. |
| | |
| | Parameters |
| | --------- |
| | X : ndarray |
| | Parameter X |
| | """ |
| | return X |
| |
|
| | @deprecated("Testing deprecated function with wrong params") |
| | def fit(self, X, y): |
| | """Incorrect docstring but should not be tested""" |
| |
|
| |
|
| | @skip_if_no_numpydoc |
| | def test_check_docstring_parameters(): |
| | incorrect = check_docstring_parameters(f_ok) |
| | assert incorrect == [] |
| | incorrect = check_docstring_parameters(f_ok, ignore=["b"]) |
| | assert incorrect == [] |
| | incorrect = check_docstring_parameters(f_missing, ignore=["b"]) |
| | assert incorrect == [] |
| | with pytest.raises(RuntimeError, match="Unknown section Results"): |
| | check_docstring_parameters(f_bad_sections) |
| | with pytest.raises(RuntimeError, match="Unknown section Parameter"): |
| | check_docstring_parameters(Klass.f_bad_sections) |
| |
|
| | incorrect = check_docstring_parameters(f_check_param_definition) |
| | mock_meta = MockMetaEstimator(delegate=MockEst()) |
| | mock_meta_name = mock_meta.__class__.__name__ |
| | assert incorrect == [ |
| | ( |
| | "sklearn.utils.tests.test_testing.f_check_param_definition There " |
| | "was no space between the param name and colon ('a: int')" |
| | ), |
| | ( |
| | "sklearn.utils.tests.test_testing.f_check_param_definition There " |
| | "was no space between the param name and colon ('b:')" |
| | ), |
| | ( |
| | "sklearn.utils.tests.test_testing.f_check_param_definition There " |
| | "was no space between the param name and colon ('d:int')" |
| | ), |
| | ] |
| |
|
| | messages = [ |
| | [ |
| | "In function: sklearn.utils.tests.test_testing.f_bad_order", |
| | ( |
| | "There's a parameter name mismatch in function docstring w.r.t." |
| | " function signature, at index 0 diff: 'b' != 'a'" |
| | ), |
| | "Full diff:", |
| | "- ['b', 'a']", |
| | "+ ['a', 'b']", |
| | ], |
| | [ |
| | "In function: " |
| | + "sklearn.utils.tests.test_testing.f_too_many_param_docstring", |
| | ( |
| | "Parameters in function docstring have more items w.r.t. function" |
| | " signature, first extra item: c" |
| | ), |
| | "Full diff:", |
| | "- ['a', 'b']", |
| | "+ ['a', 'b', 'c']", |
| | "? +++++", |
| | ], |
| | [ |
| | "In function: sklearn.utils.tests.test_testing.f_missing", |
| | ( |
| | "Parameters in function docstring have less items w.r.t. function" |
| | " signature, first missing item: b" |
| | ), |
| | "Full diff:", |
| | "- ['a', 'b']", |
| | "+ ['a']", |
| | ], |
| | [ |
| | "In function: sklearn.utils.tests.test_testing.Klass.f_missing", |
| | ( |
| | "Parameters in function docstring have less items w.r.t. function" |
| | " signature, first missing item: X" |
| | ), |
| | "Full diff:", |
| | "- ['X', 'y']", |
| | "+ []", |
| | ], |
| | [ |
| | "In function: " |
| | + f"sklearn.utils.tests.test_testing.{mock_meta_name}.predict", |
| | ( |
| | "There's a parameter name mismatch in function docstring w.r.t." |
| | " function signature, at index 0 diff: 'X' != 'y'" |
| | ), |
| | "Full diff:", |
| | "- ['X']", |
| | "? ^", |
| | "+ ['y']", |
| | "? ^", |
| | ], |
| | [ |
| | "In function: " |
| | + f"sklearn.utils.tests.test_testing.{mock_meta_name}." |
| | + "predict_proba", |
| | "potentially wrong underline length... ", |
| | "Parameters ", |
| | "--------- in ", |
| | ], |
| | [ |
| | "In function: " |
| | + f"sklearn.utils.tests.test_testing.{mock_meta_name}.score", |
| | "potentially wrong underline length... ", |
| | "Parameters ", |
| | "--------- in ", |
| | ], |
| | [ |
| | "In function: " + f"sklearn.utils.tests.test_testing.{mock_meta_name}.fit", |
| | ( |
| | "Parameters in function docstring have less items w.r.t. function" |
| | " signature, first missing item: X" |
| | ), |
| | "Full diff:", |
| | "- ['X', 'y']", |
| | "+ []", |
| | ], |
| | ] |
| |
|
| | for msg, f in zip( |
| | messages, |
| | [ |
| | f_bad_order, |
| | f_too_many_param_docstring, |
| | f_missing, |
| | Klass.f_missing, |
| | mock_meta.predict, |
| | mock_meta.predict_proba, |
| | mock_meta.score, |
| | mock_meta.fit, |
| | ], |
| | ): |
| | incorrect = check_docstring_parameters(f) |
| | assert msg == incorrect, '\n"%s"\n not in \n"%s"' % (msg, incorrect) |
| |
|
| |
|
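| | # Fixtures with partially consistent docstrings for the assert_docstring_consistency tests. |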
| | def f_one(a, b): |
| | """Function one. |
| | |
| | Parameters |
| | ---------- |
| | a : int, float |
| | Parameter a. |
| | Second line. |
| | |
| | b : str |
| | Parameter b. |
| | |
| | Returns |
| | ------- |
| | c : int |
| | Returning |
| | |
| | d : int |
| | Returning |
| | """ |
| | pass |
| |
|
| |
|
| | def f_two(a, b): |
| | """Function two. |
| | |
| | Parameters |
| | ---------- |
| | a : int, float |
| | Parameter a. |
| | Second line. |
| | |
| | b : str |
| | Parameter bb. |
| | |
| | e : int |
| | Extra parameter. |
| | |
| | Returns |
| | ------- |
| | c : int |
| | Returning |
| | |
| | d : int |
| | Returning |
| | """ |
| | pass |
| |
|
| |
|
| | def f_three(a, b): |
| | """Function two. |
| | |
| | Parameters |
| | ---------- |
| | a : int, float |
| | Parameter a. |
| | |
| | b : str |
| | Parameter B! |
| | |
| | e : |
| | Extra parameter. |
| | |
| | Returns |
| | ------- |
| | c : int |
| | Returning. |
| | |
| | d : int |
| | Returning |
| | """ |
| | pass |
| |
|
| |
|
| | @skip_if_no_numpydoc |
| | def test_assert_docstring_consistency_object_type(): |
| | """Check error raised when `objects` incorrect type.""" |
| | with pytest.raises(TypeError, match="All 'objects' must be one of"): |
| | assert_docstring_consistency(["string", f_one]) |
| |
|
| |
|
| | @skip_if_no_numpydoc |
| | @pytest.mark.parametrize( |
| | "objects, kwargs, error", |
| | [ |
| | ( |
| | [f_one, f_two], |
| | {"include_params": ["a"], "exclude_params": ["b"]}, |
| | "The 'exclude_params' argument", |
| | ), |
| | ( |
| | [f_one, f_two], |
| | {"include_returns": False, "exclude_returns": ["c"]}, |
| | "The 'exclude_returns' argument", |
| | ), |
| | ], |
| | ) |
| | def test_assert_docstring_consistency_arg_checks(objects, kwargs, error): |
| | """Check `assert_docstring_consistency` argument checking correct.""" |
| | with pytest.raises(TypeError, match=error): |
| | assert_docstring_consistency(objects, **kwargs) |
| |
|
| |
|
| | @skip_if_no_numpydoc |
| | @pytest.mark.parametrize( |
| | "objects, kwargs, error, warn", |
| | [ |
| | pytest.param( |
| | [f_one, f_two], {"include_params": ["a"]}, "", "", id="whitespace" |
| | ), |
| | pytest.param([f_one, f_two], {"include_returns": True}, "", "", id="incl_all"), |
| | pytest.param( |
| | [f_one, f_two, f_three], |
| | {"include_params": ["a"]}, |
| | ( |
| | r"The description of Parameter 'a' is inconsistent between " |
| | r"\['f_one',\n'f_two'\]" |
| | ), |
| | "", |
| | id="2-1 group", |
| | ), |
| | pytest.param( |
| | [f_one, f_two, f_three], |
| | {"include_params": ["b"]}, |
| | ( |
| | r"The description of Parameter 'b' is inconsistent between " |
| | r"\['f_one'\] and\n\['f_two'\] and" |
| | ), |
| | "", |
| | id="1-1-1 group", |
| | ), |
| | pytest.param( |
| | [f_two, f_three], |
| | {"include_params": ["e"]}, |
| | ( |
| | r"The type specification of Parameter 'e' is inconsistent between\n" |
| | r"\['f_two'\] and" |
| | ), |
| | "", |
| | id="empty type", |
| | ), |
| | pytest.param( |
| | [f_one, f_two], |
| | {"include_params": True, "exclude_params": ["b"]}, |
| | "", |
| | r"Checking was skipped for Parameters: \['e'\]", |
| | id="skip warn", |
| | ), |
| | ], |
| | ) |
| | def test_assert_docstring_consistency(objects, kwargs, error, warn): |
| | """Check `assert_docstring_consistency` gives correct results.""" |
| | if error: |
| | with pytest.raises(AssertionError, match=error): |
| | assert_docstring_consistency(objects, **kwargs) |
| | elif warn: |
| | with pytest.warns(UserWarning, match=warn): |
| | assert_docstring_consistency(objects, **kwargs) |
| | else: |
| | assert_docstring_consistency(objects, **kwargs) |
| |
|
| |
|
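| | # f_four, f_five and f_six share the `labels` parameter but word its description differently. |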
| | def f_four(labels): |
| | """Function four. |
| | |
| | Parameters |
| | ---------- |
| | |
| | labels : array-like, default=None |
| | The set of labels to include when `average != 'binary'`, and their |
| | order if `average is None`. Labels present in the data can be excluded. |
| | """ |
| | pass |
| |
|
| |
|
| | def f_five(labels): |
| | """Function five. |
| | |
| | Parameters |
| | ---------- |
| | |
| | labels : array-like, default=None |
| | The set of labels to include when `average != 'binary'`, and their |
| | order if `average is None`. This is an extra line. Labels present in the |
| | data can be excluded. |
| | """ |
| | pass |
| |
|
| |
|
| | def f_six(labels): |
| | """Function six. |
| | |
| | Parameters |
| | ---------- |
| | |
| | labels : array-like, default=None |
| | The group of labels to add when `average != 'binary'`, and the |
| | order if `average is None`. Labels present on them datas can be excluded. |
| | """ |
| | pass |
| |
|
| |
|
| | @skip_if_no_numpydoc |
| | def test_assert_docstring_consistency_error_msg(): |
| | """Check `assert_docstring_consistency` difference message.""" |
| | msg = r"""The description of Parameter 'labels' is inconsistent between |
| | \['f_four'\] and \['f_five'\] and \['f_six'\]: |
| | |
| | \*\*\* \['f_four'\] |
| | --- \['f_five'\] |
| | \*\*\*\*\*\*\*\*\*\*\*\*\*\*\* |
| | |
| | \*\*\* 10,25 \*\*\*\* |
| | |
| | --- 10,30 ---- |
| | |
| | 'binary'`, and their order if `average is None`. |
| | \+ This is an extra line. |
| | Labels present in the data can be excluded. |
| | |
| | \*\*\* \['f_four'\] |
| | --- \['f_six'\] |
| | \*\*\*\*\*\*\*\*\*\*\*\*\*\*\* |
| | |
| | \*\*\* 1,25 \*\*\*\* |
| | |
| | The |
| | ! set |
| | of labels to |
| | ! include |
| | when `average != 'binary'`, and |
| | ! their |
| | order if `average is None`. Labels present |
| | ! in the data |
| | can be excluded. |
| | --- 1,25 ---- |
| | |
| | The |
| | ! group |
| | of labels to |
| | ! add |
| | when `average != 'binary'`, and |
| | ! the |
| | order if `average is None`. Labels present |
| | ! on them datas |
| | can be excluded.""" |
| |
|
| | with pytest.raises(AssertionError, match=msg): |
| | assert_docstring_consistency([f_four, f_five, f_six], include_params=True) |
| |
|
| |
|
| | @skip_if_no_numpydoc |
| | def test_assert_docstring_consistency_descr_regex_pattern(): |
| | """Check `assert_docstring_consistency` `descr_regex_pattern` works.""" |
| | |
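| | # A pattern matching the full `labels` description of all three functions. |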
| | regex_full = ( |
| | r"The (set|group) " |
| | + r"of labels to (include|add) " |
| | + r"when `average \!\= 'binary'`, and (their|the) " |
| | + r"order if `average is None`\." |
| | + r"[\s\w]*\.* " |
| | + r"Labels present (on|in) " |
| | + r"(them|the) " |
| | + r"datas? can be excluded\." |
| | ) |
| |
|
| | assert_docstring_consistency( |
| | [f_four, f_five, f_six], |
| | include_params=True, |
| | descr_regex_pattern=" ".join(regex_full.split()), |
| | ) |
| | |
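| | # A looser pattern: matching a few key words is enough for the check to pass. |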
| | regex_words = r"(labels|average|binary)" |
| | assert_docstring_consistency( |
| | [f_four, f_five, f_six], |
| | include_params=True, |
| | descr_regex_pattern=" ".join(regex_words.split()), |
| | ) |
| | |
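| | # f_six does not match this pattern, so an AssertionError is raised. |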
| | regex_error = r"The set of labels to include when.+" |
| | msg = r"The description of Parameter 'labels' in \['f_six'\] does not match" |
| | with pytest.raises(AssertionError, match=msg): |
| | assert_docstring_consistency( |
| | [f_four, f_five, f_six], |
| | include_params=True, |
| | descr_regex_pattern=" ".join(regex_error.split()), |
| | ) |
| |
|
| |
|
| | class RegistrationCounter: |
| | def __init__(self): |
| | self.nb_calls = 0 |
| |
|
| | def __call__(self, to_register_func): |
| | self.nb_calls += 1 |
| | assert to_register_func.func is _delete_folder |
| |
|
| |
|
| | def check_memmap(input_array, mmap_data, mmap_mode="r"): |
| | assert isinstance(mmap_data, np.memmap) |
| | writeable = mmap_mode != "r" |
| | assert mmap_data.flags.writeable is writeable |
| | np.testing.assert_array_equal(input_array, mmap_data) |
| |
|
| |
|
| | def test_tempmemmap(monkeypatch): |
| | registration_counter = RegistrationCounter() |
| | monkeypatch.setattr(atexit, "register", registration_counter) |
| |
|
| | input_array = np.ones(3) |
| | with TempMemmap(input_array) as data: |
| | check_memmap(input_array, data) |
| | temp_folder = os.path.dirname(data.filename) |
| | if os.name != "nt": |
| | assert not os.path.exists(temp_folder) |
| | assert registration_counter.nb_calls == 1 |
| |
|
| | mmap_mode = "r+" |
| | with TempMemmap(input_array, mmap_mode=mmap_mode) as data: |
| | check_memmap(input_array, data, mmap_mode=mmap_mode) |
| | temp_folder = os.path.dirname(data.filename) |
| | if os.name != "nt": |
| | assert not os.path.exists(temp_folder) |
| | assert registration_counter.nb_calls == 2 |
| |
|
| |
|
| | @pytest.mark.xfail(_IS_WASM, reason="memmap not fully supported") |
| | def test_create_memmap_backed_data(monkeypatch): |
| | registration_counter = RegistrationCounter() |
| | monkeypatch.setattr(atexit, "register", registration_counter) |
| |
|
| | input_array = np.ones(3) |
| | data = create_memmap_backed_data(input_array) |
| | check_memmap(input_array, data) |
| | assert registration_counter.nb_calls == 1 |
| |
|
| | data, folder = create_memmap_backed_data(input_array, return_folder=True) |
| | check_memmap(input_array, data) |
| | assert folder == os.path.dirname(data.filename) |
| | assert registration_counter.nb_calls == 2 |
| |
|
| | mmap_mode = "r+" |
| | data = create_memmap_backed_data(input_array, mmap_mode=mmap_mode) |
| | check_memmap(input_array, data, mmap_mode) |
| | assert registration_counter.nb_calls == 3 |
| |
|
| | input_list = [input_array, input_array + 1, input_array + 2] |
| | mmap_data_list = create_memmap_backed_data(input_list) |
| | for input_array, data in zip(input_list, mmap_data_list): |
| | check_memmap(input_array, data) |
| | assert registration_counter.nb_calls == 4 |
| |
|
| | output_data, other = create_memmap_backed_data([input_array, "not-an-array"]) |
| | check_memmap(input_array, output_data) |
| | assert other == "not-an-array" |
| |
|
| |
|
| | @pytest.mark.parametrize( |
| | "constructor_name, container_type", |
| | [ |
| | ("list", list), |
| | ("tuple", tuple), |
| | ("array", np.ndarray), |
| | ("sparse", sparse.csr_matrix), |
| | |
| | |
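| | # zip keeps only as many sparse containers as the installed scipy version provides. |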
| | *zip(["sparse_csr", "sparse_csr_array"], CSR_CONTAINERS), |
| | *zip(["sparse_csc", "sparse_csc_array"], CSC_CONTAINERS), |
| | ("dataframe", lambda: pytest.importorskip("pandas").DataFrame), |
| | ("series", lambda: pytest.importorskip("pandas").Series), |
| | ("index", lambda: pytest.importorskip("pandas").Index), |
| | ("slice", slice), |
| | ], |
| | ) |
| | @pytest.mark.parametrize( |
| | "dtype, superdtype", |
| | [ |
| | (np.int32, np.integer), |
| | (np.int64, np.integer), |
| | (np.float32, np.floating), |
| | (np.float64, np.floating), |
| | ], |
| | ) |
| | def test_convert_container( |
| | constructor_name, |
| | container_type, |
| | dtype, |
| | superdtype, |
| | ): |
| | """Check that we convert the container to the right type of array with the |
| | right data type.""" |
| | if constructor_name in ("dataframe", "polars", "series", "polars_series", "index"): |
| | |
| | |
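| | # Resolve the lambda here so pandas is imported (or the test skipped) only for these cases. |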
| | container_type = container_type() |
| | container = [0, 1] |
| |
|
| | container_converted = _convert_container( |
| | container, |
| | constructor_name, |
| | dtype=dtype, |
| | ) |
| | assert isinstance(container_converted, container_type) |
| |
|
| | if constructor_name in ("list", "tuple", "index"): |
| | |
| | |
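| | # list, tuple and pandas Index do not preserve the exact requested dtype; only check its kind. |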
| | assert np.issubdtype(type(container_converted[0]), superdtype) |
| | elif hasattr(container_converted, "dtype"): |
| | assert container_converted.dtype == dtype |
| | elif hasattr(container_converted, "dtypes"): |
| | assert container_converted.dtypes[0] == dtype |
| |
|
| |
|
| | def test_convert_container_categories_pandas(): |
| | pytest.importorskip("pandas") |
| | df = _convert_container( |
| | [["x"]], "dataframe", ["A"], categorical_feature_names=["A"] |
| | ) |
| | assert df.dtypes.iloc[0] == "category" |
| |
|
| |
|
| | def test_convert_container_categories_polars(): |
| | pl = pytest.importorskip("polars") |
| | df = _convert_container([["x"]], "polars", ["A"], categorical_feature_names=["A"]) |
| | assert df.schema["A"] == pl.Categorical() |
| |
|
| |
|
| | def test_convert_container_categories_pyarrow(): |
| | pa = pytest.importorskip("pyarrow") |
| | df = _convert_container([["x"]], "pyarrow", ["A"], categorical_feature_names=["A"]) |
| | assert type(df.schema[0].type) is pa.DictionaryType |
| |
|
| |
|
| | @pytest.mark.skipif( |
| | sp_version >= parse_version("1.8"), |
| | reason="sparse arrays are available as of scipy 1.8.0", |
| | ) |
| | @pytest.mark.parametrize("constructor_name", ["sparse_csr_array", "sparse_csc_array"]) |
| | @pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64]) |
| | def test_convert_container_raise_when_sparray_not_available(constructor_name, dtype): |
| | """Check that if we convert to sparse array but sparse array are not supported |
| | (scipy<1.8.0), we should raise an explicit error.""" |
| | container = [0, 1] |
| |
|
| | with pytest.raises( |
| | ValueError, |
| | match=f"only available with scipy>=1.8.0, got {sp_version}", |
| | ): |
| | _convert_container(container, constructor_name, dtype=dtype) |
| |
|
| |
|
| | def test_raises(): |
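| | # Exercise the `raises` context manager from sklearn.utils._testing. |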
| | |
| |
|
| | |
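| | # Expected exception type, no match pattern. |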
| | with raises(TypeError): |
| | raise TypeError() |
| |
|
| | |
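| | # Expected type and a matching message pattern. |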
| | with raises(TypeError, match="how are you") as cm: |
| | raise TypeError("hello how are you") |
| | assert cm.raised_and_matched |
| |
|
| | |
| | with raises(TypeError, match=["not this one", "how are you"]) as cm: |
| | raise TypeError("hello how are you") |
| | assert cm.raised_and_matched |
| |
|
| | |
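| | # An unexpected exception type propagates unchanged. |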
| | with pytest.raises(ValueError, match="this will be raised"): |
| | with raises(TypeError) as cm: |
| | raise ValueError("this will be raised") |
| | assert not cm.raised_and_matched |
| |
|
| | |
| | with pytest.raises(AssertionError, match="the failure message"): |
| | with raises(TypeError, err_msg="the failure message") as cm: |
| | raise ValueError() |
| | assert not cm.raised_and_matched |
| |
|
| | |
| | with pytest.raises(ValueError, match="this will be raised"): |
| | with raises(TypeError, match="this is ignored") as cm: |
| | raise ValueError("this will be raised") |
| | assert not cm.raised_and_matched |
| |
|
| | |
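| | # Expected type but the message does not match the pattern. |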
| | with pytest.raises( |
| | AssertionError, match="should contain one of the following patterns" |
| | ): |
| | with raises(TypeError, match="hello") as cm: |
| | raise TypeError("Bad message") |
| | assert not cm.raised_and_matched |
| |
|
| | |
| | with pytest.raises(AssertionError, match="the failure message"): |
| | with raises(TypeError, match="hello", err_msg="the failure message") as cm: |
| | raise TypeError("Bad message") |
| | assert not cm.raised_and_matched |
| |
|
| | |
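| | # Nothing raised: fails by default (may_pass=False). |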
| | with pytest.raises(AssertionError, match="Did not raise"): |
| | with raises(TypeError) as cm: |
| | pass |
| | assert not cm.raised_and_matched |
| |
|
| | |
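| | # Nothing raised but may_pass=True, so no failure. |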
| | with raises(TypeError, match="hello", may_pass=True) as cm: |
| | pass |
| | assert not cm.raised_and_matched |
| |
|
| | |
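| | # A tuple of expected exception types is accepted. |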
| | with raises((TypeError, ValueError)): |
| | raise TypeError() |
| | with raises((TypeError, ValueError)): |
| | raise ValueError() |
| | with pytest.raises(AssertionError): |
| | with raises((TypeError, ValueError)): |
| | pass |
| |
|
| |
|
| | def test_float32_aware_assert_allclose(): |
| | |
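| | # assert_allclose relaxes the relative tolerance to 1e-4 for float32 inputs. |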
| | assert_allclose(np.array([1.0 + 2e-5], dtype=np.float32), 1.0) |
| | with pytest.raises(AssertionError): |
| | assert_allclose(np.array([1.0 + 2e-4], dtype=np.float32), 1.0) |
| |
|
| | |
| | |
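| | # For float64 inputs the relative tolerance stays at numpy's default of 1e-7. |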
| | assert_allclose(np.array([1.0 + 2e-8], dtype=np.float64), 1.0) |
| | with pytest.raises(AssertionError): |
| | assert_allclose(np.array([1.0 + 2e-7], dtype=np.float64), 1.0) |
| |
|
| | |
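| | # atol defaults to 0.0 even for float32, so it has to be passed explicitly. |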
| | with pytest.raises(AssertionError): |
| | assert_allclose(np.array([1e-5], dtype=np.float32), 0.0) |
| | assert_allclose(np.array([1e-5], dtype=np.float32), 0.0, atol=2e-5) |
| |
|
| |
|
| | @pytest.mark.xfail(_IS_WASM, reason="cannot start subprocess") |
| | def test_assert_run_python_script_without_output(): |
| | code = "x = 1" |
| | assert_run_python_script_without_output(code) |
| |
|
| | code = "print('something to stdout')" |
| | with pytest.raises(AssertionError, match="Expected no output"): |
| | assert_run_python_script_without_output(code) |
| |
|
| | code = "print('something to stdout')" |
| | with pytest.raises( |
| | AssertionError, |
| | match="output was not supposed to match.+got.+something to stdout", |
| | ): |
| | assert_run_python_script_without_output(code, pattern="to.+stdout") |
| |
|
| | code = "\n".join(["import sys", "print('something to stderr', file=sys.stderr)"]) |
| | with pytest.raises( |
| | AssertionError, |
| | match="output was not supposed to match.+got.+something to stderr", |
| | ): |
| | assert_run_python_script_without_output(code, pattern="to.+stderr") |
| |
|
| |
|
| | @pytest.mark.parametrize( |
| | "constructor_name", |
| | [ |
| | "sparse_csr", |
| | "sparse_csc", |
| | pytest.param( |
| | "sparse_csr_array", |
| | marks=pytest.mark.skipif( |
| | sp_version < parse_version("1.8"), |
| | reason="sparse arrays are available as of scipy 1.8.0", |
| | ), |
| | ), |
| | pytest.param( |
| | "sparse_csc_array", |
| | marks=pytest.mark.skipif( |
| | sp_version < parse_version("1.8"), |
| | reason="sparse arrays are available as of scipy 1.8.0", |
| | ), |
| | ), |
| | ], |
| | ) |
| | def test_convert_container_sparse_to_sparse(constructor_name): |
| | """Non-regression test to check that we can still convert a sparse container |
| | from a given format to another format. |
| | """ |
| | X_sparse = sparse.random(10, 10, density=0.1, format="csr") |
| | _convert_container(X_sparse, constructor_name) |
| |
|
| |
|
| | def check_warnings_as_errors(warning_info, warnings_as_errors): |
| | if warning_info.action == "error" and warnings_as_errors: |
| | with pytest.raises(warning_info.category, match=warning_info.message): |
| | warnings.warn( |
| | message=warning_info.message, |
| | category=warning_info.category, |
| | ) |
| | if warning_info.action == "ignore": |
| | with warnings.catch_warnings(record=True) as record: |
| | message = warning_info.message |
| | |
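| | # Keep only the prefix of the Pyarrow deprecation message so it still matches the ignore filter. |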
| | if "Pyarrow" in message: |
| | message = "\nPyarrow will become a required dependency" |
| |
|
| | warnings.warn( |
| | message=message, |
| | category=warning_info.category, |
| | ) |
| | assert len(record) == (0 if warnings_as_errors else 1) |
| | if record: |
| | assert str(record[0].message) == message |
| | assert record[0].category == warning_info.category |
| |
|
| |
|
| | @pytest.mark.parametrize("warning_info", _get_warnings_filters_info_list()) |
| | def test_sklearn_warnings_as_errors(warning_info): |
| | warnings_as_errors = os.environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0") != "0" |
| | check_warnings_as_errors(warning_info, warnings_as_errors=warnings_as_errors) |
| |
|
| |
|
| | @pytest.mark.parametrize("warning_info", _get_warnings_filters_info_list()) |
| | def test_turn_warnings_into_errors(warning_info): |
| | with warnings.catch_warnings(): |
| | turn_warnings_into_errors() |
| | check_warnings_as_errors(warning_info, warnings_as_errors=True) |
| |
|