| | import re |
| | from inspect import signature |
| | from typing import Optional |
| |
|
| | import pytest |
| |
|
| | |
| | from sklearn.experimental import ( |
| | enable_halving_search_cv, |
| | enable_iterative_imputer, |
| | ) |
| | from sklearn.utils.discovery import all_displays, all_estimators, all_functions |
| |
|
| | numpydoc_validation = pytest.importorskip("numpydoc.validate") |
| |
|
| |
|
def get_all_methods():
    """Yield ``(Klass, method)`` pairs for every public estimator and display.

    For each class returned by ``all_estimators()`` and ``all_displays()``
    whose name does not start with an underscore, one pair is produced per
    public attribute that is callable or a property, plus one pair with
    ``method=None`` standing for the class itself.

    Yields
    ------
    Klass : type
        The estimator or display class.
    method : str or None
        Name of a public callable/property of `Klass`, or None.
    """
    discovered = all_estimators() + all_displays()
    for klass_name, Klass in discovered:
        # Private classes are skipped entirely.
        if klass_name.startswith("_"):
            continue

        public_attrs = [attr for attr in dir(Klass) if not attr.startswith("_")]
        members = []
        for attr in public_attrs:
            member = getattr(Klass, attr)
            if isinstance(member, property) or hasattr(member, "__call__"):
                members.append(attr)
        # None is the placeholder entry for the class itself.
        members.append(None)

        # key=str lets None be ordered together with the string names.
        members.sort(key=str)
        for member_name in members:
            yield Klass, member_name
| |
|
| |
|
def get_all_functions_names():
    """Yield the dotted import path of every function from ``all_functions()``.

    Functions whose module path contains ``"utils.fixes"`` are left out.

    Yields
    ------
    str
        ``"<module>.<function name>"`` for each retained function.
    """
    for _, func in all_functions():
        module_path = func.__module__
        if "utils.fixes" in module_path:
            continue
        yield f"{module_path}.{func.__name__}"
| |
|
| |
|
def filter_errors(errors, method, Klass=None):
    """Drop the numpydoc errors that are not enforced for this target.

    These rules are specific for scikit-learn. See the numpydoc list of
    built-in validation checks for the meaning of each code.

    Parameters
    ----------
    errors : iterable of (str, str)
        ``(code, message)`` pairs as found in the "errors" entry of a
        numpydoc validation result.
    method : str or None
        Name of the member being validated; None means the errors belong
        to a top-level (class) docstring.
    Klass : type, default=None
        Class owning `method`. Only used to find out whether `method` is a
        property.

    Yields
    ------
    code : str
        Code of an error that is kept.
    message : str
        Message of that error.
    """
    # Never reported, whatever the target is.
    always_ignored = ("RT02", "GL01", "GL02")
    # Not reported when the validated member is a property of Klass.
    ignored_for_properties = ("PR02", "GL08")
    # Only reported when there is no method, i.e. for top-level docstrings.
    top_level_only = ("EX01", "SA01", "ES01")

    for code, message in errors:
        if code in always_ignored:
            continue

        if method is not None:
            if code in top_level_only:
                continue
            # The attribute lookup is deliberately done only for the codes
            # that need it and only when a class was given.
            if (
                code in ignored_for_properties
                and Klass is not None
                and isinstance(getattr(Klass, method), property)
            ):
                continue

        yield code, message
| |
|
| |
|
def repr_errors(res, Klass=None, method: Optional[str] = None) -> str:
    """Pretty print original docstring and the obtained errors.

    Parameters
    ----------
    res : dict
        Result of numpydoc.validate.validate. The "file", "docstring" and
        "errors" entries are read.
    Klass : {Estimator, Display, None}
        Estimator or display class, or None when the report is not about a
        class member.
    method : str, default=None
        If `Klass` is not None, the name of the member to report on; None
        then stands for ``"__init__"``. If `Klass` is None, the string is
        used verbatim as the label of the report.

    Returns
    -------
    str
        String representation of the error.

    Raises
    ------
    ValueError
        If both `Klass` and `method` are None.
    """
    if method is None:
        # Test Klass first: hasattr(None, "__init__") is True, so probing for
        # "__init__" before this check would make the error unreachable.
        if Klass is None:
            raise ValueError("At least one of Klass, method should be provided")
        method = "__init__"

    if Klass is not None:
        obj = getattr(Klass, method)
        try:
            obj_signature = str(signature(obj))
        except (TypeError, ValueError):
            # inspect.signature raises TypeError for unsupported objects
            # (properties in particular) and ValueError when no signature
            # can be provided; neither should prevent the report.
            obj_signature = (
                "\nParsing of the method signature failed, "
                "possibly because this is a property."
            )

        obj_name = Klass.__name__ + "." + method
    else:
        obj_signature = ""
        obj_name = method

    error_lines = "\n".join(
        " - {}: {}".format(code, message) for code, message in res["errors"]
    )
    sections = [
        str(res["file"]),
        obj_name + obj_signature,
        res["docstring"],
        "# Errors",
        error_lines,
    ]
    return "\n\n" + "\n\n".join(sections)
| |
|
| |
|
@pytest.mark.parametrize("function_name", get_all_functions_names())
def test_function_docstring(function_name, request):
    """Check function docstrings using numpydoc.

    The errors reported by numpydoc are passed through ``filter_errors``;
    if any remain, a ``ValueError`` carrying the formatted report is raised.
    """
    res = numpydoc_validation.validate(function_name)

    # The filtered list is stored back because repr_errors reads res["errors"].
    remaining = list(filter_errors(res["errors"], method="function"))
    res["errors"] = remaining

    if not remaining:
        return

    raise ValueError(repr_errors(res, method=f"Tested function: {function_name}"))
| |
|
| |
|
@pytest.mark.parametrize("Klass, method", get_all_methods())
def test_docstring(Klass, method, request):
    """Check class and member docstrings using numpydoc.

    The dotted path ``<module>.<class>[.<method>]`` is validated, the
    reported errors are passed through ``filter_errors``, and a
    ``ValueError`` carrying the formatted report is raised if any remain.
    """
    path_parts = [Klass.__module__, Klass.__name__]
    if method is not None:
        path_parts.append(method)
    import_path = ".".join(path_parts)

    res = numpydoc_validation.validate(import_path)

    # The filtered list is stored back because repr_errors reads res["errors"].
    remaining = list(filter_errors(res["errors"], method, Klass=Klass))
    res["errors"] = remaining

    if not remaining:
        return

    raise ValueError(repr_errors(res, Klass, method))
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: validate the docstring of one import path.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Validate docstring with numpydoc.")
    parser.add_argument("import_path", help="Import path to validate")

    args = parser.parse_args()

    res = numpydoc_validation.validate(args.import_path)

    import_path_sections = args.import_path.split(".")
    # Heuristic: the last section is taken to be a method name when the
    # section before it starts with a capitalised word (class-like name).
    # Otherwise the path is handled as a top-level object (method = None).
    parent_is_class_like = len(import_path_sections) >= 2 and re.match(
        r"(?:[A-Z][a-z]*)+", import_path_sections[-2]
    )
    method = import_path_sections[-1] if parent_is_class_like else None

    res["errors"] = list(filter_errors(res["errors"], method))

    if not res["errors"]:
        print("All docstring checks passed for {}!".format(args.import_path))
    else:
        print(repr_errors(res, method=args.import_path))
        sys.exit(1)
| |
|