| |
| |
|
|
| import builtins |
| import platform |
| import sys |
| from contextlib import suppress |
| from functools import wraps |
| from os import environ |
| from unittest import SkipTest |
|
|
| import joblib |
| import numpy as np |
| import pytest |
| from _pytest.doctest import DoctestItem |
| from threadpoolctl import threadpool_limits |
|
|
| from sklearn import set_config |
| from sklearn._min_dependencies import PYTEST_MIN_VERSION |
| from sklearn.datasets import ( |
| fetch_20newsgroups, |
| fetch_20newsgroups_vectorized, |
| fetch_california_housing, |
| fetch_covtype, |
| fetch_kddcup99, |
| fetch_lfw_pairs, |
| fetch_lfw_people, |
| fetch_olivetti_faces, |
| fetch_rcv1, |
| fetch_species_distributions, |
| ) |
| from sklearn.utils._testing import get_pytest_filterwarning_lines |
| from sklearn.utils.fixes import ( |
| _IS_32BIT, |
| np_base_version, |
| parse_version, |
| sp_version, |
| ) |
|
|
# Fail at collection time with an explicit message if the installed pytest is
# older than the minimum version scikit-learn supports.
if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
    raise ImportError(
        f"Your version of pytest is too old. Got version {pytest.__version__}, you"
        f" should have pytest >= {PYTEST_MIN_VERSION} installed."
    )


# Starting with scipy 1.10, sample images live in `scipy.datasets` and are
# downloaded on demand, so fetching them counts as a network access; earlier
# versions bundled them in `scipy.misc`.
scipy_datasets_require_network = sp_version >= parse_version("1.10")
|
|
|
|
def raccoon_face_or_skip():
    """Return the grayscale raccoon face sample image, or skip the test.

    With scipy >= 1.10 the image must be downloaded, so the test is skipped
    unless network tests are enabled via SKLEARN_SKIP_NETWORK_TESTS=0 and the
    ``pooch`` package is installed. Older scipy versions ship the image
    bundled in ``scipy.misc``.
    """
    if scipy_datasets_require_network:
        if environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") != "0":
            raise SkipTest("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")

        try:
            import pooch  # noqa: F401
        except ImportError:
            raise SkipTest("test requires pooch to be installed")

        from scipy.datasets import face
    else:
        from scipy.misc import face

    return face(gray=True)
|
|
|
|
# Map the fixture names requested by tests to the dataset fetcher to call so
# the corresponding datasets can be pre-downloaded at collection time.
dataset_fetchers = {
    "fetch_20newsgroups_fxt": fetch_20newsgroups,
    "fetch_20newsgroups_vectorized_fxt": fetch_20newsgroups_vectorized,
    "fetch_california_housing_fxt": fetch_california_housing,
    "fetch_covtype_fxt": fetch_covtype,
    "fetch_kddcup99_fxt": fetch_kddcup99,
    "fetch_lfw_pairs_fxt": fetch_lfw_pairs,
    "fetch_lfw_people_fxt": fetch_lfw_people,
    "fetch_olivetti_faces_fxt": fetch_olivetti_faces,
    "fetch_rcv1_fxt": fetch_rcv1,
    "fetch_species_distributions_fxt": fetch_species_distributions,
}

# The raccoon face image requires network access only with scipy >= 1.10
# (see `scipy_datasets_require_network` above).
if scipy_datasets_require_network:
    dataset_fetchers["raccoon_face_fxt"] = raccoon_face_or_skip

# Skip float32 parametrizations unless explicitly enabled through the
# SKLEARN_RUN_FLOAT32_TESTS environment variable.
_SKIP32_MARK = pytest.mark.skipif(
    environ.get("SKLEARN_RUN_FLOAT32_TESTS", "0") != "1",
    reason="Set SKLEARN_RUN_FLOAT32_TESTS=1 to run float32 dtype tests",
)
|
|
|
|
| |
# Parametrize over both floating dtypes; the float32 run is skipped unless
# SKLEARN_RUN_FLOAT32_TESTS=1 (see _SKIP32_MARK).
@pytest.fixture(params=[pytest.param(np.float32, marks=_SKIP32_MARK), np.float64])
def global_dtype(request):
    """Yield the floating point dtype to use for dtype-parametrized tests."""
    yield request.param
|
|
|
|
def _fetch_fixture(f):
    """Wrap dataset fetcher ``f`` as a pytest fixture honoring the network policy.

    The download policy is decided once from SKLEARN_SKIP_NETWORK_TESTS:
    downloads are only allowed when the variable is set to "0". When the data
    is missing locally and downloads are disabled, the test is skipped.
    """
    download_if_missing = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"

    @wraps(f)
    def _checked_fetch(*args, **kwargs):
        # Impose the environment-driven download policy on every call.
        kwargs["download_if_missing"] = download_if_missing
        try:
            return f(*args, **kwargs)
        except OSError as exc:
            # Any OSError other than the "dataset missing locally" error
            # raised when downloads are disabled must propagate.
            if str(exc) != "Data not found and `download_if_missing` is False":
                raise
            pytest.skip("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")

    # The fixture returns the wrapped callable itself (tests call it).
    return pytest.fixture(lambda: _checked_fetch)
|
|
|
|
| |
# Instantiate the `*_fxt` fixtures listed in `dataset_fetchers`: each one
# returns the corresponding fetcher with the network policy applied.
fetch_20newsgroups_fxt = _fetch_fixture(fetch_20newsgroups)
fetch_20newsgroups_vectorized_fxt = _fetch_fixture(fetch_20newsgroups_vectorized)
fetch_california_housing_fxt = _fetch_fixture(fetch_california_housing)
fetch_covtype_fxt = _fetch_fixture(fetch_covtype)
fetch_kddcup99_fxt = _fetch_fixture(fetch_kddcup99)
fetch_lfw_pairs_fxt = _fetch_fixture(fetch_lfw_pairs)
fetch_lfw_people_fxt = _fetch_fixture(fetch_lfw_people)
fetch_olivetti_faces_fxt = _fetch_fixture(fetch_olivetti_faces)
fetch_rcv1_fxt = _fetch_fixture(fetch_rcv1)
fetch_species_distributions_fxt = _fetch_fixture(fetch_species_distributions)
# The raccoon face helper already handles skipping itself, so it is exposed
# as a plain fixture rather than going through `_fetch_fixture`.
raccoon_face_fxt = pytest.fixture(raccoon_face_or_skip)
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Called after collect is completed.

    Parameters
    ----------
    config : pytest config
    items : list of collected items
    """
    run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
    skip_network = pytest.mark.skip(
        reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0"
    )

    # Collect the datasets used by the selected items, either via a doctest
    # named after a `fetch_*` function or via one of the `*_fxt` fixtures.
    dataset_features_set = set(dataset_fetchers)
    datasets_to_download = set()

    for item in items:
        if isinstance(item, DoctestItem) and "fetch_" in item.name:
            fetcher_function_name = item.name.split(".")[-1]
            dataset_fetchers_key = f"{fetcher_function_name}_fxt"
            dataset_to_fetch = {dataset_fetchers_key} & dataset_features_set
        elif not hasattr(item, "fixturenames"):
            # Not a test item that can request fixtures.
            continue
        else:
            item_fixtures = set(item.fixturenames)
            dataset_to_fetch = item_fixtures & dataset_features_set

        if not dataset_to_fetch:
            continue

        if run_network_tests:
            datasets_to_download |= dataset_to_fetch
        else:
            # Network tests are disabled: mark the item skipped instead of
            # letting the fetcher fail at run time.
            item.add_marker(skip_network)

    # Download the needed datasets once, up front, on a single pytest-xdist
    # worker ("gw0") to avoid concurrent downloads of the same dataset by
    # several workers. Without xdist, PYTEST_XDIST_WORKER is unset and the
    # default "gw0" makes this process do the downloading.
    worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0")
    if worker_id == "gw0" and run_network_tests:
        for name in datasets_to_download:
            with suppress(SkipTest):
                dataset_fetchers[name]()

    for item in items:
        # Known upstream failure of this estimator's common tests on aarch64.
        if (
            item.name.endswith("GradientBoostingClassifier")
            and platform.machine() == "aarch64"
        ):
            marker = pytest.mark.xfail(
                reason=(
                    # Fixed typo: was "know failure".
                    "known failure. See "
                    "https://github.com/scikit-learn/scikit-learn/issues/17797"
                )
            )
            item.add_marker(marker)

    # Decide whether doctests should run at all on this platform/config.
    # NOTE: later conditions overwrite `reason`; only the last applicable
    # reason is reported, which matches the original behavior.
    skip_doctests = False
    try:
        import matplotlib  # noqa: F401
    except ImportError:
        skip_doctests = True
        reason = "matplotlib is required to run the doctests"

    if _IS_32BIT:
        # Fixed typo: was "doctest are".
        reason = "doctests are only run when the default numpy int is 64 bits."
        skip_doctests = True
    elif sys.platform.startswith("win32"):
        reason = (
            "doctests are not run for Windows because numpy arrays "
            "repr is inconsistent across platforms."
        )
        skip_doctests = True

    if np_base_version < parse_version("2"):
        # The doctest outputs assume the NEP 51 scalar repr introduced in
        # numpy 2, so they cannot pass on older numpy.
        reason = "Due to NEP 51 numpy scalar repr has changed in numpy 2"
        skip_doctests = True

    if sp_version < parse_version("1.14"):
        # Doctest outputs assume the sparse container repr from scipy 1.14.
        reason = "Scipy sparse matrix repr has changed in scipy 1.14"
        skip_doctests = True

    # Normally a doctest runs with the whole module's globals in scope.
    # Reset globs to an empty dict so each doctest is self-contained:
    # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context
    for item in items:
        if isinstance(item, DoctestItem):
            item.dtest.globs = {}

    if skip_doctests:
        skip_marker = pytest.mark.skip(reason=reason)

        for item in items:
            if isinstance(item, DoctestItem):
                # Work around a pytest internal error when adding a skip
                # mark to a doctest defined in a contextmanager, see
                # https://github.com/pytest-dev/pytest/issues/8796
                if item.name != "sklearn._config.config_context":
                    item.add_marker(skip_marker)
    try:
        import PIL  # noqa: F401

        pillow_installed = True
    except ImportError:
        pillow_installed = False

    if not pillow_installed:
        # These image-handling doctests need pillow to load sample images.
        skip_marker = pytest.mark.skip(reason="pillow (or PIL) not installed!")
        for item in items:
            if item.name in [
                "sklearn.feature_extraction.image.PatchExtractor",
                "sklearn.feature_extraction.image.extract_patches_2d",
            ]:
                item.add_marker(skip_marker)
|
|
|
|
@pytest.fixture(scope="function")
def pyplot():
    """Setup and teardown fixture for matplotlib.

    Skips the test when matplotlib cannot be imported. Otherwise all open
    figures are closed before the test receives ``matplotlib.pyplot`` and
    again once the test is done.

    Returns
    -------
    pyplot : module
        The ``matplotlib.pyplot`` module.
    """
    plt = pytest.importorskip("matplotlib.pyplot")
    plt.close("all")
    yield plt
    plt.close("all")
|
|
|
|
def pytest_generate_tests(metafunc):
    """Parametrization of global_random_seed fixture

    based on the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable.

    The goal of this fixture is to prevent tests that use it to be sensitive
    to a specific seed value while still being deterministic by default.

    See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
    variable for instructions on how to use this fixture.

    https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed

    """
    seed_spec = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")

    if seed_spec is None:
        # Unset: stay deterministic by default with one fixed seed.
        seeds = [42]
    elif seed_spec == "all":
        # Exercise the full supported seed range.
        seeds = list(range(100))
    else:
        if "-" in seed_spec:
            # An inclusive "start-stop" range of seeds.
            low, high = seed_spec.split("-")
            seeds = list(range(int(low), int(high) + 1))
        else:
            seeds = [int(seed_spec)]

        # Validate only explicitly requested seeds; "all" and the default
        # are valid by construction.
        if min(seeds) < 0 or max(seeds) > 99:
            raise ValueError(
                "The value(s) of the environment variable "
                "SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
                f"(or 'all'), got: {seed_spec}"
            )

    if "global_random_seed" in metafunc.fixturenames:
        metafunc.parametrize("global_random_seed", seeds)
|
|
|
|
def pytest_configure(config):
    """Pytest hook: set matplotlib backend, thread limits and warning filters."""
    # Use a non-interactive backend when matplotlib is available so tests
    # never need a display.
    try:
        import matplotlib

        matplotlib.use("agg")
    except ImportError:
        pass

    n_threads = joblib.cpu_count(only_physical_cores=True)
    xdist_workers = environ.get("PYTEST_XDIST_WORKER_COUNT")
    if xdist_workers is not None:
        # Split the physical cores between the pytest-xdist workers to avoid
        # oversubscribing the native thread pools.
        n_threads = max(n_threads // int(xdist_workers), 1)
    threadpool_limits(n_threads)

    if environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0") != "0":
        # Opt-in: append the scikit-learn filterwarnings lines to the pytest
        # configuration so warnings are raised as errors.
        for line in get_pytest_filterwarning_lines():
            config.addinivalue_line("filterwarnings", line)
|
|
|
|
@pytest.fixture
def hide_available_pandas(monkeypatch):
    """Pretend pandas was not installed."""
    real_import = builtins.__import__

    def _import_without_pandas(name, *args, **kwargs):
        # Only pandas is hidden; every other import goes through untouched.
        if name == "pandas":
            raise ImportError()
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", _import_without_pandas)
|
|
|
|
@pytest.fixture
def print_changed_only_false():
    """Set `print_changed_only` to False for the duration of the test."""
    set_config(print_changed_only=False)
    yield  # run the test with the non-default repr configuration
    set_config(print_changed_only=True)  # restore the default at teardown
|
|