| import pytest | |
| from sklearn.utils._mask import safe_mask | |
| from sklearn.utils.fixes import CSR_CONTAINERS | |
| from sklearn.utils.validation import check_random_state | |
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_safe_mask(csr_container):
    """Check that ``safe_mask`` yields masks usable on dense and CSR input.

    A boolean row mask selecting the last three rows is passed through
    ``safe_mask`` for a dense array and then for its CSR counterpart built
    with ``csr_container``. Both must select exactly those three rows.
    """
    random_state = check_random_state(0)
    X = random_state.rand(5, 4)
    X_csr = csr_container(X)
    mask = [False, False, True, True, True]

    # Dense input: the returned mask indexes the ndarray directly.
    mask = safe_mask(X, mask)
    assert X[mask].shape[0] == 3
    assert (X[mask] == X[2:]).all()

    # Sparse input: the mask returned for the dense case is fed back in.
    # The result must be usable to index the CSR matrix and must select
    # the same rows.
    mask = safe_mask(X_csr, mask)
    assert X_csr[mask].shape[0] == 3
    assert (X_csr[mask].toarray() == X[2:]).all()