spam-classifier
/
venv
/lib
/python3.11
/site-packages
/sklearn
/tests
/test_docstring_parameters.py
| # Authors: The scikit-learn developers | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| import importlib | |
| import inspect | |
| import os | |
| import warnings | |
| from inspect import signature | |
| from pkgutil import walk_packages | |
| import numpy as np | |
| import pytest | |
| import sklearn | |
| from sklearn import metrics | |
| from sklearn.datasets import make_classification | |
| from sklearn.ensemble import StackingClassifier, StackingRegressor | |
| # make it possible to discover experimental estimators when calling `all_estimators` | |
| from sklearn.experimental import ( | |
| enable_halving_search_cv, # noqa | |
| enable_iterative_imputer, # noqa | |
| ) | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.preprocessing import FunctionTransformer | |
| from sklearn.utils import all_estimators | |
| from sklearn.utils._test_common.instance_generator import _construct_instances | |
| from sklearn.utils._testing import ( | |
| _get_func_name, | |
| assert_docstring_consistency, | |
| check_docstring_parameters, | |
| ignore_warnings, | |
| skip_if_no_numpydoc, | |
| ) | |
| from sklearn.utils.deprecation import _is_deprecated | |
| from sklearn.utils.estimator_checks import ( | |
| _enforce_estimator_tags_X, | |
| _enforce_estimator_tags_y, | |
| ) | |
# walk_packages() ignores DeprecationWarnings, now we need to ignore
# FutureWarnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    # mypy error: Module has no attribute "__path__"
    sklearn_path = [os.path.dirname(sklearn.__file__)]
    # Every importable public sklearn module; private ("._") and test
    # (".tests.") modules are excluded from the docstring checks below.
    PUBLIC_MODULES = set(
        [
            pckg[1]
            for pckg in walk_packages(prefix="sklearn.", path=sklearn_path)
            if not ("._" in pckg[1] or ".tests." in pckg[1])
        ]
    )

# functions to ignore args / docstring of
# TODO(1.7): remove "sklearn.utils._joblib"
_DOCSTRING_IGNORES = [
    "sklearn.utils.deprecation.load_mlcomp",
    "sklearn.pipeline.make_pipeline",
    "sklearn.pipeline.make_union",
    "sklearn.utils.extmath.safe_sparse_dot",
    "sklearn.utils._joblib",
    "HalfBinomialLoss",
]

# Methods where y param should be ignored if y=None by default
_METHODS_IGNORE_NONE_Y = [
    "fit",
    "score",
    "fit_predict",
    "fit_transform",
    "partial_fit",
    "predict",
]
def test_docstring_parameters():
    """Check numpydoc-style docstring parameters for all public classes/functions.

    Walks every module in ``PUBLIC_MODULES``, parses each public class and
    free function with numpydoc, and compares the documented parameters
    against the actual signatures via ``check_docstring_parameters``.
    Deprecated objects and names listed in ``_DOCSTRING_IGNORES`` are skipped.
    """
    # Test module docstring formatting

    # Skip test if numpydoc is not found
    pytest.importorskip(
        "numpydoc", reason="numpydoc is required to test the docstrings"
    )

    # XXX unreached code as of v0.22
    from numpydoc import docscrape

    incorrect = []
    for name in PUBLIC_MODULES:
        if name.endswith(".conftest"):
            # pytest tooling, not part of the scikit-learn API
            continue
        if name == "sklearn.utils.fixes":
            # We cannot always control these docstrings
            continue
        with warnings.catch_warnings(record=True):
            module = importlib.import_module(name)
        classes = inspect.getmembers(module, inspect.isclass)
        # Exclude non-scikit-learn classes
        classes = [cls for cls in classes if cls[1].__module__.startswith("sklearn")]
        for cname, cls in classes:
            this_incorrect = []
            if cname in _DOCSTRING_IGNORES or cname.startswith("_"):
                continue
            if inspect.isabstract(cls):
                continue

            # numpydoc parsing warnings indicate a malformed class docstring;
            # surface them as hard errors rather than silently ignoring them.
            with warnings.catch_warnings(record=True) as w:
                cdoc = docscrape.ClassDoc(cls)
            if len(w):
                raise RuntimeError(
                    "Error for __init__ of %s in %s:\n%s" % (cls, name, w[0])
                )

            # Skip checks on deprecated classes
            if _is_deprecated(cls.__new__):
                continue

            this_incorrect += check_docstring_parameters(cls.__init__, cdoc)

            for method_name in cdoc.methods:
                method = getattr(cls, method_name)
                if _is_deprecated(method):
                    continue
                param_ignore = None
                # Now skip docstring test for y when y is None
                # by default for API reason
                if method_name in _METHODS_IGNORE_NONE_Y:
                    sig = signature(method)
                    if "y" in sig.parameters and sig.parameters["y"].default is None:
                        param_ignore = ["y"]  # ignore y for fit and score
                result = check_docstring_parameters(method, ignore=param_ignore)
                this_incorrect += result

            incorrect += this_incorrect

        functions = inspect.getmembers(module, inspect.isfunction)
        # Exclude imported functions
        functions = [fn for fn in functions if fn[1].__module__ == name]
        for fname, func in functions:
            # Don't test private methods / functions
            if fname.startswith("_"):
                continue
            if fname == "configuration" and name.endswith("setup"):
                continue
            name_ = _get_func_name(func)
            if not any(d in name_ for d in _DOCSTRING_IGNORES) and not _is_deprecated(
                func
            ):
                incorrect += check_docstring_parameters(func)

    msg = "\n".join(incorrect)
    if len(incorrect) > 0:
        raise AssertionError("Docstring Error:\n" + msg)
def _construct_searchcv_instance(SearchCV):
    """Build a minimal *SearchCV* estimator over a tiny logistic-regression grid."""
    base_estimator = LogisticRegression()
    param_grid = {"C": [0.1, 1]}
    return SearchCV(base_estimator, param_grid)
| def _construct_compose_pipeline_instance(Estimator): | |
| # Minimal / degenerate instances: only useful to test the docstrings. | |
| if Estimator.__name__ == "ColumnTransformer": | |
| return Estimator(transformers=[("transformer", "passthrough", [0, 1])]) | |
| elif Estimator.__name__ == "Pipeline": | |
| return Estimator(steps=[("clf", LogisticRegression())]) | |
| elif Estimator.__name__ == "FeatureUnion": | |
| return Estimator(transformer_list=[("transformer", FunctionTransformer())]) | |
| def _construct_sparse_coder(Estimator): | |
| # XXX: hard-coded assumption that n_features=3 | |
| dictionary = np.array( | |
| [[0, 1, 0], [-1, -1, 2], [1, 1, 1], [0, 1, 1], [0, 2, 1]], | |
| dtype=np.float64, | |
| ) | |
| return Estimator(dictionary=dictionary) | |
def test_fit_docstring_attributes(name, Estimator):
    """Check fitted attributes match the numpydoc ``Attributes`` section.

    Fits a minimal instance of *Estimator* on tiny synthetic data, then
    verifies (a) every documented attribute exists on the fitted estimator
    (unless its description contains "only", marking it conditional) and
    (b) every public trailing-underscore attribute on the fitted estimator
    is documented.
    """
    pytest.importorskip("numpydoc")
    from numpydoc import docscrape

    doc = docscrape.ClassDoc(Estimator)
    attributes = doc["Attributes"]

    # Some estimators need hand-built degenerate instances because their
    # defaults cannot be constructed generically.
    if Estimator.__name__ in (
        "HalvingRandomSearchCV",
        "RandomizedSearchCV",
        "HalvingGridSearchCV",
        "GridSearchCV",
    ):
        est = _construct_searchcv_instance(Estimator)
    elif Estimator.__name__ in (
        "ColumnTransformer",
        "Pipeline",
        "FeatureUnion",
    ):
        est = _construct_compose_pipeline_instance(Estimator)
    elif Estimator.__name__ == "SparseCoder":
        est = _construct_sparse_coder(Estimator)
    elif Estimator.__name__ == "FrozenEstimator":
        X, y = make_classification(n_samples=20, n_features=5, random_state=0)
        est = Estimator(LogisticRegression().fit(X, y))
    else:
        # TODO(devtools): use _tested_estimators instead of all_estimators in the
        # decorator
        est = next(_construct_instances(Estimator))

    # Per-estimator parameter tweaks so that fitting the tiny dataset succeeds.
    if Estimator.__name__ == "SelectKBest":
        est.set_params(k=2)
    elif Estimator.__name__ == "DummyClassifier":
        est.set_params(strategy="stratified")
    elif Estimator.__name__ == "CCA" or Estimator.__name__.startswith("PLS"):
        # default = 2 is invalid for single target
        est.set_params(n_components=1)
    elif Estimator.__name__ in (
        "GaussianRandomProjection",
        "SparseRandomProjection",
    ):
        # default="auto" raises an error with the shape of `X`
        est.set_params(n_components=2)
    elif Estimator.__name__ == "TSNE":
        # default raises an error, perplexity must be less than n_samples
        est.set_params(perplexity=2)

    # Low max iter to speed up tests: we are only interested in checking the existence
    # of fitted attributes. This should be invariant to whether it has converged or not.
    if "max_iter" in est.get_params():
        est.set_params(max_iter=2)
        # min value for `TSNE` is 250
        if Estimator.__name__ == "TSNE":
            est.set_params(max_iter=250)

    if "random_state" in est.get_params():
        est.set_params(random_state=0)

    # In case we want to deprecate some attributes in the future
    skipped_attributes = {}

    if Estimator.__name__.endswith("Vectorizer"):
        # Vectorizer require some specific input data
        if Estimator.__name__ in (
            "CountVectorizer",
            "HashingVectorizer",
            "TfidfVectorizer",
        ):
            X = [
                "This is the first document.",
                "This document is the second document.",
                "And this is the third one.",
                "Is this the first document?",
            ]
        elif Estimator.__name__ == "DictVectorizer":
            X = [{"foo": 1, "bar": 2}, {"foo": 3, "baz": 1}]
        y = None
    else:
        X, y = make_classification(
            n_samples=20,
            n_features=3,
            n_redundant=0,
            n_classes=2,
            random_state=2,
        )

        y = _enforce_estimator_tags_y(est, y)
        X = _enforce_estimator_tags_X(est, X)

    # Dispatch the fit call on the estimator's declared input/target shape.
    if est.__sklearn_tags__().target_tags.one_d_labels:
        est.fit(y)
    elif est.__sklearn_tags__().target_tags.two_d_labels:
        est.fit(np.c_[y, y])
    elif est.__sklearn_tags__().input_tags.three_d_array:
        est.fit(X[np.newaxis, ...], y)
    else:
        est.fit(X, y)

    for attr in attributes:
        if attr.name in skipped_attributes:
            continue
        desc = " ".join(attr.desc).lower()
        # As certain attributes are present "only" if a certain parameter is
        # provided, this checks if the word "only" is present in the attribute
        # description, and if not the attribute is required to be present.
        if "only " in desc:
            continue
        # ignore deprecation warnings
        with ignore_warnings(category=FutureWarning):
            assert hasattr(est, attr.name)

    fit_attr = _get_all_fitted_attributes(est)
    fit_attr_names = [attr.name for attr in attributes]
    undocumented_attrs = set(fit_attr).difference(fit_attr_names)
    undocumented_attrs = set(undocumented_attrs).difference(skipped_attributes)
    if undocumented_attrs:
        raise AssertionError(
            f"Undocumented attributes for {Estimator.__name__}: {undocumented_attrs}"
        )
| def _get_all_fitted_attributes(estimator): | |
| "Get all the fitted attributes of an estimator including properties" | |
| # attributes | |
| fit_attr = list(estimator.__dict__.keys()) | |
| # properties | |
| with warnings.catch_warnings(): | |
| warnings.filterwarnings("error", category=FutureWarning) | |
| for name in dir(estimator.__class__): | |
| obj = getattr(estimator.__class__, name) | |
| if not isinstance(obj, property): | |
| continue | |
| # ignore properties that raises an AttributeError and deprecated | |
| # properties | |
| try: | |
| getattr(estimator, name) | |
| except (AttributeError, FutureWarning): | |
| continue | |
| fit_attr.append(name) | |
| return [k for k in fit_attr if k.endswith("_") and not k.startswith("_")] | |
def test_precision_recall_f_score_docstring_consistency():
    """Check docstrings parameters of related metrics are consistent."""
    metrics_to_check = [
        metrics.precision_recall_fscore_support,
        metrics.f1_score,
        metrics.fbeta_score,
        metrics.precision_score,
        metrics.recall_score,
    ]
    assert_docstring_consistency(
        metrics_to_check,
        include_params=True,
        # "zero_division" - the reason for zero division differs between f scores,
        # precision and recall.
        exclude_params=["average", "zero_division"],
    )
    # Expected shared description for the "average" parameter; whitespace is
    # normalized below before matching, so literal indentation is irrelevant.
    description_regex = (
        r"""This parameter is required for multiclass/multilabel targets\.
        If ``None``, the metrics for each class are returned\. Otherwise, this
        determines the type of averaging performed on the data:

        ``'binary'``:
            Only report results for the class specified by ``pos_label``\.
            This is applicable only if targets \(``y_\{true,pred\}``\) are binary\.
        ``'micro'``:
            Calculate metrics globally by counting the total true positives,
            false negatives and false positives\.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean\. This does not take label imbalance into account\.
        ``'weighted'``:
            Calculate metrics for each label, and find their average weighted
            by support \(the number of true instances for each label\)\. This
            alters 'macro' to account for label imbalance; it can result in an
            F-score that is not between precision and recall\."""
        + r"[\s\w]*\.*"  # optionally match additional sentence
        + r"""
        ``'samples'``:
            Calculate metrics for each instance, and find their average \(only
            meaningful for multilabel classification where this differs from
            :func:`accuracy_score`\)\."""
    )
    assert_docstring_consistency(
        metrics_to_check,
        include_params=["average"],
        descr_regex_pattern=" ".join(description_regex.split()),
    )
def test_stacking_classifier_regressor_docstring_consistency():
    """Check docstrings parameters stacking estimators are consistent."""
    shared_params = ["cv", "n_jobs", "passthrough", "verbose"]
    stacking_estimators = [StackingClassifier, StackingRegressor]
    assert_docstring_consistency(
        stacking_estimators,
        include_params=shared_params,
        include_attrs=True,
        exclude_attrs=["final_estimator_"],
    )