spam-classifier
/
venv
/lib
/python3.11
/site-packages
/sklearn
/preprocessing
/tests
/test_common.py
| import warnings | |
| import numpy as np | |
| import pytest | |
| from sklearn.base import clone | |
| from sklearn.datasets import load_iris | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import ( | |
| MaxAbsScaler, | |
| MinMaxScaler, | |
| PowerTransformer, | |
| QuantileTransformer, | |
| RobustScaler, | |
| StandardScaler, | |
| maxabs_scale, | |
| minmax_scale, | |
| power_transform, | |
| quantile_transform, | |
| robust_scale, | |
| scale, | |
| ) | |
| from sklearn.utils._testing import assert_allclose, assert_array_equal | |
| from sklearn.utils.fixes import ( | |
| BSR_CONTAINERS, | |
| COO_CONTAINERS, | |
| CSC_CONTAINERS, | |
| CSR_CONTAINERS, | |
| DIA_CONTAINERS, | |
| DOK_CONTAINERS, | |
| LIL_CONTAINERS, | |
| ) | |
# Iris dataset, loaded once when this module is imported.
# test_missing_value_handling works on a copy of ``iris.data``.
iris = load_iris()
def _get_valid_samples_by_column(X, col):
    """Return the non-NaN entries of column ``col`` of ``X``.

    The column is selected with a one-element index list, so for a 2-D
    array the result keeps a trailing axis of length one.
    """
    not_missing = np.logical_not(np.isnan(X[:, col]))
    column = X[:, [col]]
    return column[not_missing]
def test_missing_value_handling(
    est, func, support_sparse, strictly_positive, omit_kwargs
):
    """Check that ``est`` and ``func`` let NaN values pass through.

    NOTE(review): none of the arguments is defined in the visible part of
    this file; presumably they are injected by a ``pytest.mark.parametrize``
    decorator that is not shown here -- confirm.

    Parameters
    ----------
    est : estimator
        Object providing ``fit``, ``transform``, ``inverse_transform`` and
        ``get_params``. It is refitted several times during the test.
    func : callable
        Called as ``func(X_train, **kwargs)``; its output is compared with
        ``est.transform(X_train)``.
    support_sparse : bool
        When true, the dense results are also compared with the results
        obtained from every sparse container.
    strictly_positive : bool
        When true, the data is shifted before being split.
    omit_kwargs : iterable of str
        Names removed from ``est.get_params()`` before calling ``func``.
    """
    rng = np.random.RandomState(42)
    X = iris.data.copy()

    # Blank out 50 randomly drawn (row, column) positions. The row indices
    # are drawn before the column indices.
    n_missing = 50
    nan_rows = rng.randint(X.shape[0], size=n_missing)
    nan_cols = rng.randint(X.shape[1], size=n_missing)
    X[nan_rows, nan_cols] = np.nan

    if strictly_positive:
        # NOTE(review): the minimum is added rather than subtracted, so this
        # only yields positive data when nanmin(X) is above -0.05. The iris
        # measurements presumably satisfy that -- confirm.
        X += np.nanmin(X) + 0.1

    X_train, X_test = train_test_split(X, random_state=1)

    # sanity check: no training column is entirely NaN, and every column of
    # both splits contains at least one NaN
    train_nan = np.isnan(X_train)
    assert not np.all(train_nan, axis=0).any()
    assert np.any(train_nan, axis=0).all()
    assert np.any(np.isnan(X_test), axis=0).all()
    X_test[:, 0] = np.nan  # boundary case: a test column that is all NaN

    # any RuntimeWarning emitted while fitting/transforming becomes an error
    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        Xt = est.fit(X_train).transform(X_test)
    # missing values should still be missing, and only them
    assert_array_equal(np.isnan(Xt), np.isnan(X_test))

    # the function must lead to the same results as the class
    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        Xt_class = est.transform(X_train)
    func_kwargs = est.get_params()
    # drop the parameters that the function counterpart does not define
    for name in omit_kwargs:
        del func_kwargs[name]
    Xt_func = func(X_train, **func_kwargs)
    func_nan = np.isnan(Xt_func)
    class_nan = np.isnan(Xt_class)
    assert_array_equal(func_nan, class_nan)
    assert_allclose(Xt_func[~func_nan], Xt_class[~class_nan])

    # the inverse transform must keep the NaN values in place
    Xt_inv = est.inverse_transform(Xt)
    inv_nan = np.isnan(Xt_inv)
    test_nan = np.isnan(X_test)
    assert_array_equal(inv_nan, test_nan)
    # FIXME: equal_nan=True could be used with recent versions of numpy.
    # For the moment, only check that the non-NaN values are almost equal.
    assert_allclose(Xt_inv[~inv_nan], X_test[~test_nan])

    n_features = X.shape[1]
    for col in range(n_features):
        # train only on the non-NaN samples of this column
        est.fit(_get_valid_samples_by_column(X_train, col))
        # transforming data with NaN must work even when trained without NaN
        with warnings.catch_warnings():
            warnings.simplefilter("error", RuntimeWarning)
            Xt_col = est.transform(X_test[:, [col]])
        assert_allclose(Xt_col, Xt[:, [col]])

        if np.isnan(X_test[:, col]).all():
            # the first test column is all NaN: no valid sample to compare
            continue
        # non-NaN samples must be transformed exactly as before
        Xt_col_nonan = est.transform(_get_valid_samples_by_column(X_test, col))
        assert_array_equal(Xt_col_nonan, Xt_col[~np.isnan(Xt_col.squeeze())])

    if not support_sparse:
        return

    est_dense = clone(est)
    est_sparse = clone(est)

    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        Xt_dense = est_dense.fit(X_train).transform(X_test)
        Xt_inv_dense = est_dense.inverse_transform(Xt_dense)

    sparse_containers = (
        BSR_CONTAINERS
        + COO_CONTAINERS
        + CSC_CONTAINERS
        + CSR_CONTAINERS
        + DIA_CONTAINERS
        + DOK_CONTAINERS
        + LIL_CONTAINERS
    )
    for sparse_container in sparse_containers:
        # dense and sparse inputs must lead to the same results; the sparse
        # matrices are built beforehand to avoid catching side warnings
        X_train_sp = sparse_container(X_train)
        X_test_sp = sparse_container(X_test)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", PendingDeprecationWarning)
            warnings.simplefilter("error", RuntimeWarning)
            Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp)
        assert_allclose(Xt_sp.toarray(), Xt_dense)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", PendingDeprecationWarning)
            warnings.simplefilter("error", RuntimeWarning)
            Xt_inv_sp = est_sparse.inverse_transform(Xt_sp)
        assert_allclose(Xt_inv_sp.toarray(), Xt_inv_dense)
def test_missing_value_pandas_na_support(est, func):
    """Check that ``est`` gives the same output for NaN and ``pd.NA`` input.

    The same values are passed to ``est.fit_transform`` once as a NumPy
    array holding NaN and once as a pandas DataFrame, and the two results
    are compared. The test is skipped when pandas cannot be imported.

    NOTE(review): ``func`` is not used in this test, and neither argument is
    defined in the visible part of this file; presumably both come from a
    ``pytest.mark.parametrize`` decorator that is not shown -- confirm.
    """
    pd = pytest.importorskip("pandas")

    col_a = [1, 2, 3, np.nan, np.nan, 4, 5, 1]
    col_b = [np.nan, np.nan, 8, 4, 6, np.nan, np.nan, 8]
    col_c = [1, 2, 3, 4, 5, 6, 7, 8]
    X = np.array([col_a, col_b, col_c]).T

    # Build the dataframe with the "Int16" dtype (IntegerArrays with pd.NA),
    # then cast column "c", which has no missing entry, to "int".
    X_df = pd.DataFrame(X, dtype="Int16", columns=["a", "b", "c"])
    X_df["c"] = X_df["c"].astype("int")

    from_ndarray = est.fit_transform(X)
    from_frame = est.fit_transform(X_df)

    assert_allclose(from_ndarray, from_frame)