import hashlib
import io
import os
import re
import shutil
import tempfile
import warnings
from functools import partial
from importlib import resources
from pathlib import Path
from pickle import dumps, loads
from unittest.mock import Mock
from urllib.error import HTTPError
from urllib.parse import urlparse

import numpy as np
import pytest

from sklearn.datasets import (
    clear_data_home,
    fetch_file,
    get_data_home,
    load_breast_cancer,
    load_diabetes,
    load_digits,
    load_files,
    load_iris,
    load_linnerud,
    load_sample_image,
    load_sample_images,
    load_wine,
)
from sklearn.datasets._base import (
    RemoteFileMetadata,
    _derive_folder_and_filename_from_url,
    _fetch_remote,
    load_csv_data,
    load_gzip_compressed_csv_data,
)
from sklearn.datasets.tests.test_common import check_as_frame
from sklearn.preprocessing import scale
from sklearn.utils import Bunch


class _DummyPath:
    """Minimal class that implements the os.PathLike interface."""

    def __init__(self, path):
        self.path = path

    def __fspath__(self):
        return self.path


def _remove_dir(path):
    if os.path.isdir(path):
        shutil.rmtree(path)


@pytest.fixture(scope="module")
def data_home(tmpdir_factory):
    tmp_file = str(tmpdir_factory.mktemp("scikit_learn_data_home_test"))
    yield tmp_file
    _remove_dir(tmp_file)


@pytest.fixture(scope="module")
def load_files_root(tmpdir_factory):
    tmp_file = str(tmpdir_factory.mktemp("scikit_learn_load_files_test"))
    yield tmp_file
    _remove_dir(tmp_file)


@pytest.fixture
def test_category_dir_1(load_files_root):
    test_category_dir1 = tempfile.mkdtemp(dir=load_files_root)
    sample_file = tempfile.NamedTemporaryFile(dir=test_category_dir1, delete=False)
    sample_file.write(b"Hello World!\n")
    sample_file.close()
    yield str(test_category_dir1)
    _remove_dir(test_category_dir1)


@pytest.fixture
def test_category_dir_2(load_files_root):
    test_category_dir2 = tempfile.mkdtemp(dir=load_files_root)
    yield str(test_category_dir2)
    _remove_dir(test_category_dir2)


@pytest.mark.parametrize("path_container", [None, Path, _DummyPath])
def test_data_home(path_container, data_home):
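    # get_data_home will point to a pre-existing folder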
    expected_path = data_home  # the fixture always provides a plain str
    if path_container is not None:
        data_home = path_container(data_home)
    data_home = get_data_home(data_home=data_home)
    assert data_home == expected_path
    assert os.path.exists(data_home)

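    # clear_data_home will delete both the content and the folder itself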
    if path_container is not None:
        data_home = path_container(data_home)
    clear_data_home(data_home=data_home)
    assert not os.path.exists(data_home)

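    # if the folder is missing it will be created again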
    data_home = get_data_home(data_home=data_home)
    assert os.path.exists(data_home)


def test_default_empty_load_files(load_files_root):
    res = load_files(load_files_root)
    assert len(res.filenames) == 0
    assert len(res.target_names) == 0
    assert res.DESCR is None


def test_default_load_files(test_category_dir_1, test_category_dir_2, load_files_root):
    res = load_files(load_files_root)
    assert len(res.filenames) == 1
    assert len(res.target_names) == 2
    assert res.DESCR is None
    assert res.data == [b"Hello World!\n"]


def test_load_files_w_categories_desc_and_encoding(
    test_category_dir_1, test_category_dir_2, load_files_root
):
    category = os.path.abspath(test_category_dir_1).split(os.sep).pop()
    res = load_files(
        load_files_root, description="test", categories=[category], encoding="utf-8"
    )

    assert len(res.filenames) == 1
    assert len(res.target_names) == 1
    assert res.DESCR == "test"
    assert res.data == ["Hello World!\n"]


def test_load_files_wo_load_content(
    test_category_dir_1, test_category_dir_2, load_files_root
):
    res = load_files(load_files_root, load_content=False)
    assert len(res.filenames) == 1
    assert len(res.target_names) == 2
    assert res.DESCR is None
    assert res.get("data") is None


| @pytest.mark.parametrize("allowed_extensions", ([".txt"], [".txt", ".json"])) |
| def test_load_files_allowed_extensions(tmp_path, allowed_extensions): |
| """Check the behaviour of `allowed_extension` in `load_files`.""" |
| d = tmp_path / "sub" |
| d.mkdir() |
| files = ("file1.txt", "file2.json", "file3.json", "file4.md") |
| paths = [d / f for f in files] |
| for p in paths: |
| p.write_bytes(b"hello") |
| res = load_files(tmp_path, allowed_extensions=allowed_extensions) |
| assert set([str(p) for p in paths if p.suffix in allowed_extensions]) == set( |
| res.filenames |
| ) |
|
|
|
|
@pytest.mark.parametrize(
    "filename, expected_n_samples, expected_n_features, expected_target_names",
    [
        ("wine_data.csv", 178, 13, ["class_0", "class_1", "class_2"]),
        ("iris.csv", 150, 4, ["setosa", "versicolor", "virginica"]),
        ("breast_cancer.csv", 569, 30, ["malignant", "benign"]),
    ],
)
def test_load_csv_data(
    filename, expected_n_samples, expected_n_features, expected_target_names
):
    actual_data, actual_target, actual_target_names = load_csv_data(filename)
    assert actual_data.shape[0] == expected_n_samples
    assert actual_data.shape[1] == expected_n_features
    assert actual_target.shape[0] == expected_n_samples
    np.testing.assert_array_equal(actual_target_names, expected_target_names)


def test_load_csv_data_with_descr():
    data_file_name = "iris.csv"
    descr_file_name = "iris.rst"

    res_without_descr = load_csv_data(data_file_name=data_file_name)
    res_with_descr = load_csv_data(
        data_file_name=data_file_name, descr_file_name=descr_file_name
    )
    assert len(res_with_descr) == 4
    assert len(res_without_descr) == 3

    np.testing.assert_array_equal(res_with_descr[0], res_without_descr[0])
    np.testing.assert_array_equal(res_with_descr[1], res_without_descr[1])
    np.testing.assert_array_equal(res_with_descr[2], res_without_descr[2])

    assert res_with_descr[-1].startswith(".. _iris_dataset:")


@pytest.mark.parametrize(
    "filename, kwargs, expected_shape",
    [
        ("diabetes_data_raw.csv.gz", {}, [442, 10]),
        ("diabetes_target.csv.gz", {}, [442]),
        ("digits.csv.gz", {"delimiter": ","}, [1797, 65]),
    ],
)
def test_load_gzip_compressed_csv_data(filename, kwargs, expected_shape):
    actual_data = load_gzip_compressed_csv_data(filename, **kwargs)
    assert actual_data.shape == tuple(expected_shape)


def test_load_gzip_compressed_csv_data_with_descr():
    data_file_name = "diabetes_target.csv.gz"
    descr_file_name = "diabetes.rst"

    expected_data = load_gzip_compressed_csv_data(data_file_name=data_file_name)
    actual_data, descr = load_gzip_compressed_csv_data(
        data_file_name=data_file_name,
        descr_file_name=descr_file_name,
    )

    np.testing.assert_array_equal(actual_data, expected_data)
    assert descr.startswith(".. _diabetes_dataset:")


def test_load_sample_images():
    try:
        res = load_sample_images()
        assert len(res.images) == 2
        assert len(res.filenames) == 2
        images = res.images

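        # check the top-left pixel of the first sample image (china.jpg)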
        assert np.all(images[0][0, 0, :] == np.array([174, 201, 231], dtype=np.uint8))
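        # check the top-left pixel of the second sample image (flower.jpg)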
        assert np.all(images[1][0, 0, :] == np.array([2, 19, 13], dtype=np.uint8))
        assert res.DESCR
    except ImportError:
        warnings.warn("Could not load sample images, PIL is not available.")


def test_load_sample_image():
    try:
        china = load_sample_image("china.jpg")
        assert china.dtype == "uint8"
        assert china.shape == (427, 640, 3)
    except ImportError:
        warnings.warn("Could not load sample images, PIL is not available.")


def test_load_diabetes_raw():
    """Test to check that we load a scaled version by default but that we can
    get an unscaled version when setting `scaled=False`."""
    diabetes_raw = load_diabetes(scaled=False)
    assert diabetes_raw.data.shape == (442, 10)
    assert diabetes_raw.target.size == 442
    assert len(diabetes_raw.feature_names) == 10
    assert diabetes_raw.DESCR

    diabetes_default = load_diabetes()

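    # The default version mean-centers each feature and scales it by its
    # standard deviation times sqrt(n_samples), so that each column of the
    # scaled data has unit L2 norm.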
    np.testing.assert_allclose(
        scale(diabetes_raw.data) / (442**0.5), diabetes_default.data, atol=1e-04
    )


@pytest.mark.parametrize(
    "loader_func, data_shape, target_shape, n_target, has_descr, filenames",
    [
        (load_breast_cancer, (569, 30), (569,), 2, True, ["filename"]),
        (load_wine, (178, 13), (178,), 3, True, []),
        (load_iris, (150, 4), (150,), 3, True, ["filename"]),
        (
            load_linnerud,
            (20, 3),
            (20, 3),
            3,
            True,
            ["data_filename", "target_filename"],
        ),
        (load_diabetes, (442, 10), (442,), None, True, []),
        (load_digits, (1797, 64), (1797,), 10, True, []),
        (partial(load_digits, n_class=9), (1617, 64), (1617,), 10, True, []),
    ],
)
def test_loader(loader_func, data_shape, target_shape, n_target, has_descr, filenames):
    bunch = loader_func()

    assert isinstance(bunch, Bunch)
    assert bunch.data.shape == data_shape
    assert bunch.target.shape == target_shape
    if hasattr(bunch, "feature_names"):
        assert len(bunch.feature_names) == data_shape[1]
    if n_target is not None:
        assert len(bunch.target_names) == n_target
    if has_descr:
        assert bunch.DESCR
    if filenames:
        assert "data_module" in bunch
        assert all(
            f in bunch and (resources.files(bunch["data_module"]) / bunch[f]).is_file()
            for f in filenames
        )


@pytest.mark.parametrize(
    "loader_func, data_dtype, target_dtype",
    [
        (load_breast_cancer, np.float64, int),
        (load_diabetes, np.float64, np.float64),
        (load_digits, np.float64, int),
        (load_iris, np.float64, int),
        (load_linnerud, np.float64, np.float64),
        (load_wine, np.float64, int),
    ],
)
def test_toy_dataset_frame_dtype(loader_func, data_dtype, target_dtype):
    default_result = loader_func()
    check_as_frame(
        default_result,
        loader_func,
        expected_data_dtype=data_dtype,
        expected_target_dtype=target_dtype,
    )


def test_loads_dumps_bunch():
    bunch = Bunch(x="x")
    bunch_from_pkl = loads(dumps(bunch))
    bunch_from_pkl.x = "y"
    assert bunch_from_pkl["x"] == bunch_from_pkl.x


def test_bunch_pickle_generated_with_0_16_and_read_with_0_17():
    bunch = Bunch(key="original")
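    # Bunch pickles created with scikit-learn 0.16 carry a non-empty
    # __dict__: reading `bunch.key` would then hit the stale __dict__ entry,
    # whereas `bunch.key = ...` goes through __setattr__ and only updates the
    # underlying dict. Unpickling must therefore ignore __dict__ entirely.
    # Simulate such an old pickle: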
    bunch.__dict__["key"] = "set from __dict__"
    bunch_from_pkl = loads(dumps(bunch))
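    # After unpickling, the stale __dict__ entry should have been discarded.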
    assert bunch_from_pkl.key == "original"
    assert bunch_from_pkl["key"] == "original"

    # Assigning through attribute access updates both access paths.
    bunch_from_pkl.key = "changed"
    assert bunch_from_pkl.key == "changed"
    assert bunch_from_pkl["key"] == "changed"


def test_bunch_dir():
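    # check that dir (important for autocomplete) shows attributes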
    data = load_iris()
    assert "data" in dir(data)


def test_load_boston_error():
    """Check that we raise the ethical warning when trying to import `load_boston`."""
    msg = "The Boston housing prices dataset has an ethical problem"
    with pytest.raises(ImportError, match=msg):
        from sklearn.datasets import load_boston

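    # other non-existing names should raise the usual ImportError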
| msg = "cannot import name 'non_existing_function' from 'sklearn.datasets'" |
| with pytest.raises(ImportError, match=msg): |
| from sklearn.datasets import non_existing_function |
|
|
|
|
def test_fetch_remote_raise_warnings_with_invalid_url(monkeypatch):
    """Check retry mechanism in _fetch_remote."""

    url = "https://scikit-learn.org/this_file_does_not_exist.tar.gz"
    invalid_remote_file = RemoteFileMetadata("invalid_file", url, None)
    urlretrieve_mock = Mock(
        side_effect=HTTPError(
            url=url, code=404, msg="Not Found", hdrs=None, fp=io.BytesIO()
        )
    )
    monkeypatch.setattr("sklearn.datasets._base.urlretrieve", urlretrieve_mock)

    with pytest.warns(UserWarning, match="Retry downloading") as record:
        with pytest.raises(HTTPError, match="HTTP Error 404"):
            _fetch_remote(invalid_remote_file, n_retries=3, delay=0)

    # one initial attempt plus 3 retries
    assert urlretrieve_mock.call_count == 4

    for r in record:
        assert str(r.message) == f"Retry downloading from url: {url}"
    assert len(record) == 3


def test_derive_folder_and_filename_from_url():
    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com/file.tar.gz"
    )
    assert folder == "example.com"
    assert filename == "file.tar.gz"

    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com/نمونه نماینده.data"
    )
    assert folder == "example.com"
    assert filename == "نمونه-نماینده.data"

    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com/path/to-/.file.tar.gz"
    )
    assert folder == "example.com/path_to"
    assert filename == "file.tar.gz"

    folder, filename = _derive_folder_and_filename_from_url("https://example.com/")
    assert folder == "example.com"
    assert filename == "downloaded_file"

    folder, filename = _derive_folder_and_filename_from_url("https://example.com")
    assert folder == "example.com"
    assert filename == "downloaded_file"

    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com/path/@to/data.json?param=value"
    )
    assert folder == "example.com/path_to"
    assert filename == "data.json"

    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com/path/@@to._/-_.data.json.#anchor"
    )
    assert folder == "example.com/path_to"
    assert filename == "data.json"

    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com//some_file.txt"
    )
    assert folder == "example.com"
    assert filename == "some_file.txt"

    folder, filename = _derive_folder_and_filename_from_url(
        "http://example/../some_file.txt"
    )
    assert folder == "example"
    assert filename == "some_file.txt"

    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com/!.'.,/some_file.txt"
    )
    assert folder == "example.com"
    assert filename == "some_file.txt"

    folder, filename = _derive_folder_and_filename_from_url(
        "https://example.com/a/!.'.,/b/some_file.txt"
    )
    assert folder == "example.com/a_b"
    assert filename == "some_file.txt"

    folder, filename = _derive_folder_and_filename_from_url("https://example.com/!.'.,")
    assert folder == "example.com"
    assert filename == "downloaded_file"

    with pytest.raises(ValueError, match="Invalid URL"):
        _derive_folder_and_filename_from_url("https:/../")


def _mock_urlretrieve(server_side):
    def _urlretrieve_mock(url, local_path):
        server_root = Path(server_side)
        file_path = urlparse(url).path.strip("/")
        if not (server_root / file_path).exists():
            raise HTTPError(url, 404, "Not Found", None, None)
        shutil.copy(server_root / file_path, local_path)

    return Mock(side_effect=_urlretrieve_mock)


def test_fetch_file_using_data_home(monkeypatch, tmpdir):
    tmpdir = Path(tmpdir)
    server_side = tmpdir / "server_side"
    server_side.mkdir()
    data_file = server_side / "data.jsonl"
    server_data = '{"a": 1, "b": 2}\n'
    data_file.write_text(server_data, encoding="utf-8")

    server_subfolder = server_side / "subfolder"
    server_subfolder.mkdir()
    other_data_file = server_subfolder / "other_file.txt"
    other_data_file.write_text("Some important text data.", encoding="utf-8")

    data_home = tmpdir / "data_home"
    data_home.mkdir()

    urlretrieve_mock = _mock_urlretrieve(server_side)
    monkeypatch.setattr("sklearn.datasets._base.urlretrieve", urlretrieve_mock)

    monkeypatch.setattr(
        "sklearn.datasets._base.get_data_home", Mock(return_value=data_home)
    )
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl",
    )
    assert fetched_file_path == data_home / "example.com" / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data

    fetched_file_path = fetch_file(
        "https://example.com/subfolder/other_file.txt",
    )
    assert (
        fetched_file_path == data_home / "example.com" / "subfolder" / "other_file.txt"
    )
    assert fetched_file_path.read_text(encoding="utf-8") == other_data_file.read_text(
        "utf-8"
    )

    expected_warning_msg = re.escape(
        "Retry downloading from url: https://example.com/subfolder/invalid.txt"
    )
    with pytest.raises(HTTPError):
        with pytest.warns(match=expected_warning_msg):
            fetch_file(
                "https://example.com/subfolder/invalid.txt",
                delay=0,
            )

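    # The failed download should not leave a partial file behind locally.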
    local_subfolder = data_home / "example.com" / "subfolder"
    assert sorted(local_subfolder.iterdir()) == [local_subfolder / "other_file.txt"]


def test_fetch_file_without_sha256(monkeypatch, tmpdir):
    server_side = tmpdir.mkdir("server_side")
    data_file = Path(server_side / "data.jsonl")
    server_data = '{"a": 1, "b": 2}\n'
    data_file.write_text(server_data, encoding="utf-8")

    client_side = tmpdir.mkdir("client_side")

    urlretrieve_mock = _mock_urlretrieve(server_side)
    monkeypatch.setattr("sklearn.datasets._base.urlretrieve", urlretrieve_mock)

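    # The first call should trigger a download.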
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl",
        folder=client_side,
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 1

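    # Fetching the same file again should reuse the local copy without downloading.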
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl",
        folder=client_side,
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 1

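    # Deleting the local file and fetching again should trigger a new download.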
    fetched_file_path.unlink()
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl",
        folder=client_side,
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 2


def test_fetch_file_with_sha256(monkeypatch, tmpdir):
    server_side = tmpdir.mkdir("server_side")
    data_file = Path(server_side / "data.jsonl")
    server_data = '{"a": 1, "b": 2}\n'
    data_file.write_text(server_data, encoding="utf-8")
    expected_sha256 = hashlib.sha256(data_file.read_bytes()).hexdigest()

    client_side = tmpdir.mkdir("client_side")

    urlretrieve_mock = _mock_urlretrieve(server_side)
    monkeypatch.setattr("sklearn.datasets._base.urlretrieve", urlretrieve_mock)

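    # The first call should trigger a download and verify the checksum.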
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 1

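    # Fetching again while the local checksum matches should not re-download.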
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 1

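    # Corrupting the local file should trigger a re-download with a warning.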
    fetched_file_path.write_text("corrupted contents", encoding="utf-8")
    expected_msg = (
        r"SHA256 checksum of existing local file data.jsonl "
        rf"\(.*\) differs from expected \({expected_sha256}\): "
        r"re-downloading from https://example.com/data.jsonl \."
    )
    with pytest.warns(match=expected_msg):
        fetched_file_path = fetch_file(
            "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256
        )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 2

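    # Fetching again with a matching checksum should not re-download.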
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 2

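    # Deleting the local file and fetching again should trigger a re-download.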
    fetched_file_path.unlink()
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl", folder=client_side, sha256=expected_sha256
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 3

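    # Omitting the sha256 should reuse the cached file without re-downloading.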
    fetched_file_path = fetch_file(
        "https://example.com/data.jsonl",
        folder=client_side,
    )
    assert fetched_file_path == client_side / "data.jsonl"
    assert fetched_file_path.read_text(encoding="utf-8") == server_data
    assert urlretrieve_mock.call_count == 3

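    # A mismatching sha256 should warn about the stale local file and raise
    # when the freshly downloaded file still does not match.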
| non_matching_sha256 = "deadbabecafebeef" |
| expected_warning_msg = "differs from expected" |
| expected_error_msg = re.escape( |
| f"The SHA256 checksum of data.jsonl ({expected_sha256}) differs from " |
| f"expected ({non_matching_sha256})." |
| ) |
| with pytest.raises(OSError, match=expected_error_msg): |
| with pytest.warns(match=expected_warning_msg): |
| fetch_file( |
| "https://example.com/data.jsonl", |
| folder=client_side, |
| sha256=non_matching_sha256, |
| ) |
|
|