| """Test the 20news downloader, if the data is available, |
| or if specifically requested via environment variable |
| (e.g. for CI jobs).""" |
|
|
| from functools import partial |
| from unittest.mock import patch |
|
|
| import numpy as np |
| import pytest |
| import scipy.sparse as sp |
|
|
| from sklearn.datasets.tests.test_common import ( |
| check_as_frame, |
| check_pandas_dependency_message, |
| check_return_X_y, |
| ) |
| from sklearn.preprocessing import normalize |
| from sklearn.utils._testing import assert_allclose_dense_sparse |
|
|
|
|
def test_20news(fetch_20newsgroups_fxt):
    """Check category filtering and ``return_X_y`` against the full dataset."""
    full = fetch_20newsgroups_fxt(subset="all", shuffle=False)
    assert full.DESCR.startswith(".. _20newsgroups_dataset:")

    # Fetch a reduced dataset, requesting the last two categories in
    # reverse order.
    last_two_reversed = full.target_names[-1:-3:-1]
    reduced = fetch_20newsgroups_fxt(
        subset="all", categories=last_two_reversed, shuffle=False
    )

    # The reduced dataset must list its categories in the same order as the
    # full dataset (not in the reversed order they were requested in) ...
    assert reduced.target_names == full.target_names[-2:]
    # ... and must only contain the labels 0 and 1.
    assert np.unique(reduced.target).tolist() == [0, 1]

    # filenames, target and data must all have the same number of entries.
    n_files = len(reduced.filenames)
    assert n_files == len(reduced.target)
    assert n_files == len(reduced.data)

    # The first entry of the reduced dataset must equal the first entry of
    # the same category in the full dataset.
    first_category = reduced.target_names[reduced.target[0]]
    full_label = full.target_names.index(first_category)
    matching_rows = np.where(full.target == full_label)[0]
    assert reduced.data[0] == full.data[matching_rows[0]]

    # With return_X_y=True the sizes must match the bunch of the full dataset.
    X, y = fetch_20newsgroups_fxt(subset="all", shuffle=False, return_X_y=True)
    assert len(X) == len(full.data)
    assert y.shape == full.target.shape
|
|
|
|
def test_20news_length_consistency(fetch_20newsgroups_fxt):
    """Check that the bunch reports consistent lengths for each field.

    Non-regression test for a bug present in 0.16.1.
    """
    bunch = fetch_20newsgroups_fxt(subset="all")
    # Item access and attribute access must expose containers of equal length.
    for field in ("data", "target", "filenames"):
        assert len(bunch[field]) == len(getattr(bunch, field))
|
|
|
|
def test_20news_vectorized(fetch_20newsgroups_vectorized_fxt):
    """Check format, shape, dtype and description of each vectorized subset."""
    n_train, n_test, n_features = 11314, 7532, 130107

    def check_bunch(bunch, n_samples):
        # Shared expectations: sparse CSR float64 data with one target per row
        # and the dataset description attached.
        assert sp.issparse(bunch.data) and bunch.data.format == "csr"
        assert bunch.data.shape == (n_samples, n_features)
        assert bunch.target.shape[0] == n_samples
        assert bunch.data.dtype == np.float64
        assert bunch.DESCR.startswith(".. _20newsgroups_dataset:")

    # "train" subset
    check_bunch(fetch_20newsgroups_vectorized_fxt(subset="train"), n_train)

    # "test" subset
    test_bunch = fetch_20newsgroups_vectorized_fxt(subset="test")
    check_bunch(test_bunch, n_test)

    # return_X_y option, checked against the "test" bunch fetched above
    fetch_test = partial(fetch_20newsgroups_vectorized_fxt, subset="test")
    check_return_X_y(test_bunch, fetch_test)

    # "all" subset: train and test together
    check_bunch(fetch_20newsgroups_vectorized_fxt(subset="all"), n_train + n_test)
|
|
|
|
def test_20news_normalization(fetch_20newsgroups_vectorized_fxt):
    """Check that ``normalize=True`` matches normalizing the raw data by hand."""
    # Only the first rows are compared, which keeps the dense conversion
    # below cheap.
    n_rows = 100
    X_raw = fetch_20newsgroups_vectorized_fxt(normalize=False)["data"][:n_rows]
    X_norm = fetch_20newsgroups_vectorized_fxt(normalize=True)["data"][:n_rows]

    # The normalized fetch must agree with normalizing the raw fetch ourselves.
    assert_allclose_dense_sparse(X_norm, normalize(X_raw))

    # Every row of the normalized data must have norm 1. toarray() yields a
    # plain ndarray, whereas todense() may return the legacy np.matrix type.
    row_norms = np.linalg.norm(X_norm.toarray(), axis=1)
    assert np.allclose(row_norms, 1)
|
|
|
|
def test_20news_as_frame(fetch_20newsgroups_vectorized_fxt):
    """Check the ``as_frame=True`` output of the vectorized fetcher.

    Skipped when pandas is not installed.
    """
    pd = pytest.importorskip("pandas")

    bunch = fetch_20newsgroups_vectorized_fxt(as_frame=True)
    check_as_frame(bunch, fetch_20newsgroups_vectorized_fxt)

    frame = bunch.frame
    assert frame.shape == (11314, 130108)
    # A generator lets all() stop at the first non-sparse column instead of
    # first building a list with one boolean per column.
    assert all(isinstance(dtype, pd.SparseDtype) for dtype in bunch.data.dtypes)

    # Spot-check a handful of feature names, then the target column.
    columns = frame.columns
    for expected_feature in (
        "beginner",
        "beginners",
        "beginning",
        "beginnings",
        "begins",
        "begley",
        "begone",
    ):
        assert expected_feature in columns
    assert "category_class" in columns
    assert bunch.target.name == "category_class"
|
|
|
|
def test_as_frame_no_pandas(fetch_20newsgroups_vectorized_fxt, hide_available_pandas):
    """Run the shared pandas-dependency message check on the vectorized fetcher."""
    # hide_available_pandas is requested only for its side effect; it is
    # defined outside this file (presumably it makes pandas look uninstalled
    # -- confirm against the fixture definition).
    check_pandas_dependency_message(fetch_20newsgroups_vectorized_fxt)
|
|
|
|
def test_outdated_pickle(fetch_20newsgroups_vectorized_fxt):
    """Check that an unexpected cached payload is rejected with a ValueError.

    The test pretends that every path exists and that loading the cache
    yields a bare ``("X", "y")`` tuple, then expects the fetcher to raise.
    """
    # Make every path look like it exists ...
    exists_patch = patch("os.path.exists", return_value=True)
    # ... and make joblib.load hand back a two-element tuple.
    load_patch = patch("joblib.load", return_value=("X", "y"))
    expected_msg = "The cached dataset located in"

    with exists_patch, load_patch, pytest.raises(ValueError, match=expected_msg):
        fetch_20newsgroups_vectorized_fxt(as_frame=True)
|
|