| """Global configuration state and functions for management""" |
|
|
| |
| |
|
|
| import os |
| import threading |
| from contextlib import contextmanager as contextmanager |
|
|
def _env_flag(name, default=False):
    """Interpret the environment variable `name` as a boolean flag.

    Parameters
    ----------
    name : str
        Name of the environment variable to read.
    default : bool, default=False
        Value returned when the variable is not set at all.

    Returns
    -------
    flag : bool
        `default` when the variable is unset. Otherwise False when the value
        is empty or one of the conventional "off" spellings (``0``,
        ``false``, ``no``, ``off``; case-insensitive, surrounding whitespace
        ignored) and True for any other value.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    # A plain ``bool(raw)`` would be True for *any* non-empty string, so
    # ``SKLEARN_ASSUME_FINITE=0`` would silently switch the flag on. Only the
    # explicit "off" spellings are mapped to False; every other non-empty
    # value keeps meaning True.
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Default value of every configuration option. Three of them can be overridden
# through environment variables, which are read once, when this module is
# imported.
_global_config = {
    # Boolean flag, overridable with SKLEARN_ASSUME_FINITE.
    "assume_finite": _env_flag("SKLEARN_ASSUME_FINITE", False),
    # Integer, overridable with SKLEARN_WORKING_MEMORY.
    "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
    "print_changed_only": True,
    "display": "diagram",
    # Integer, overridable with SKLEARN_PAIRWISE_DIST_CHUNK_SIZE.
    "pairwise_dist_chunk_size": int(
        os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
    ),
    "enable_cython_pairwise_dist": True,
    "array_api_dispatch": False,
    "transform_output": "default",
    "enable_metadata_routing": False,
    "skip_parameter_validation": False,
}
# Per-thread storage: each thread gets its own mutable copy of the defaults
# above the first time it asks for its configuration.
_threadlocal = threading.local()
|
|
|
|
def _get_threadlocal_config():
    """Return the **mutable** configuration dict of the calling thread.

    The first time a thread asks for its configuration, the dict is created
    as a shallow copy of the module-level defaults in ``_global_config``.
    """
    try:
        return _threadlocal.global_config
    except AttributeError:
        # Nothing stored for this thread yet: seed it from the defaults.
        _threadlocal.global_config = _global_config.copy()
        return _threadlocal.global_config
|
|
|
|
def get_config():
    """Return a snapshot of the configuration currently in effect.

    The reported values are the ones controlled through :func:`set_config`.

    Returns
    -------
    config : dict
        Maps each parameter name to its current value. Every key can be
        passed as a keyword argument to :func:`set_config`.

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    set_config : Set global scikit-learn configuration.

    Examples
    --------
    >>> import sklearn
    >>> config = sklearn.get_config()
    >>> config.keys()
    dict_keys([...])
    """
    # Hand out a shallow copy rather than the thread-local dict itself, so
    # mutating the returned mapping cannot change the active configuration.
    snapshot = dict(_get_threadlocal_config())
    return snapshot
|
|
|
|
def set_config(
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Set global scikit-learn configuration.

    Only the options given a value other than `None` are modified; every
    other option keeps its current value.

    .. versionadded:: 0.19

    Parameters
    ----------
    assume_finite : bool, default=None
        When True, the finiteness validation is skipped, which saves time
        but can lead to crashes. When False, the finiteness validation is
        carried out, avoiding the error. Global default: False.

        .. versionadded:: 0.19

    working_memory : int, default=None
        When set, scikit-learn tries to keep temporary arrays below this
        number of MiB (per job when parallelised). This often saves both
        computation time and memory on expensive operations that can be
        performed in chunks. Global default: 1024.

        .. versionadded:: 0.20

    print_changed_only : bool, default=None
        When True, printing an estimator only shows the parameters that were
        set to non-default values. For example, ``print(SVC())`` while True
        will only print 'SVC()' while the default behaviour would be to print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed
        parameters.

        .. versionadded:: 0.21

    display : {'text', 'diagram'}, default=None
        With 'diagram', estimators are displayed as a diagram in a Jupyter
        lab or notebook context. With 'text', estimators are displayed as
        text. Default is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        Number of row vectors per chunk used by the accelerated
        pairwise-distances reduction backend. Default is 256 (suitable for
        most of modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn
        internals. End users are not expected to benefit from customizing
        this configuration setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Whether to use the accelerated pairwise-distances reduction backend
        when possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn
        internals. End users are not expected to benefit from customizing
        this configuration setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Whether to use Array API dispatching when inputs follow the Array
        API standard. Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for
        more details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and
        values in the fit method of estimators and for arguments passed to
        public helper functions. It can save time in some situations but can
        lead to low level crashes and exceptions with confusing error
        messages.

        Note that for data parameters, such as `X` and `y`, only type
        validation is skipped but validation with `check_array` will continue
        to run.

        .. versionadded:: 1.3

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Examples
    --------
    >>> from sklearn import set_config
    >>> set_config(display='diagram')  # doctest: +SKIP
    """
    config = _get_threadlocal_config()

    # Options are applied strictly in this order, one at a time, so that a
    # failing array-API check leaves the earlier options already written and
    # the later ones untouched.
    requested = (
        ("assume_finite", assume_finite),
        ("working_memory", working_memory),
        ("print_changed_only", print_changed_only),
        ("display", display),
        ("pairwise_dist_chunk_size", pairwise_dist_chunk_size),
        ("enable_cython_pairwise_dist", enable_cython_pairwise_dist),
        ("array_api_dispatch", array_api_dispatch),
        ("transform_output", transform_output),
        ("enable_metadata_routing", enable_metadata_routing),
        ("skip_parameter_validation", skip_parameter_validation),
    )

    for option, value in requested:
        if value is None:
            # None means "leave this option as it is".
            continue
        if option == "array_api_dispatch":
            # Imported lazily, only when this option is actually being set.
            from .utils._array_api import _check_array_api_dispatch

            # The value is checked before it is stored.
            _check_array_api_dispatch(value)
        config[option] = value
|
|
|
|
@contextmanager
def config_context(
    *,
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Context manager for global scikit-learn configuration.

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. If None, the existing value won't change.
        The default value is False.

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. If None, the existing value won't change.
        The default value is 1024.

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. If None, the existing value won't change.
        The default value is True.

        .. versionchanged:: 0.23
           Default changed from False to True.

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. If None, the existing value won't change.
        The default value is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        The number of row vectors per chunk for the accelerated pairwise-
        distances reduction backend. Default is 256 (suitable for most of
        modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Use the accelerated pairwise-distances reduction backend when
        possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Use Array API dispatching when inputs follow the Array API standard.
        Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for more
        details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and values in
        the fit method of estimators and for arguments passed to public helper
        functions. It can save time in some situations but can lead to low level
        crashes and exceptions with confusing error messages.

        Note that for data parameters, such as `X` and `y`, only type validation is
        skipped but validation with `check_array` will continue to run.

        .. versionadded:: 1.3

    Yields
    ------
    None.

    See Also
    --------
    set_config : Set global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Notes
    -----
    All settings, not just those presently modified, will be returned to
    their previous values when the context manager is exited.

    The previous values are also restored when applying the requested
    settings fails part-way, so a failed ``with`` statement does not leave
    a partially modified configuration behind.

    Examples
    --------
    >>> import sklearn
    >>> from sklearn.utils.validation import assert_all_finite
    >>> with sklearn.config_context(assume_finite=True):
    ...     assert_all_finite([float('nan')])
    >>> with sklearn.config_context(assume_finite=True):
    ...     with sklearn.config_context(assume_finite=False):
    ...         assert_all_finite([float('nan')])
    Traceback (most recent call last):
    ...
    ValueError: Input contains NaN...
    """
    # Snapshot of every option, taken before anything is modified.
    old_config = get_config()

    try:
        # Applied inside the ``try`` on purpose: set_config writes options
        # one by one and can raise after some of them have been stored
        # (e.g. when the array-API dispatch check fails). The ``finally``
        # clause then rolls those partial changes back as well.
        set_config(
            assume_finite=assume_finite,
            working_memory=working_memory,
            print_changed_only=print_changed_only,
            display=display,
            pairwise_dist_chunk_size=pairwise_dist_chunk_size,
            enable_cython_pairwise_dist=enable_cython_pairwise_dist,
            array_api_dispatch=array_api_dispatch,
            transform_output=transform_output,
            enable_metadata_routing=enable_metadata_routing,
            skip_parameter_validation=skip_parameter_validation,
        )
        yield
    finally:
        # Restore the full snapshot, whether the body completed, raised, or
        # was never entered.
        set_config(**old_config)
|
|