| |
| |
|
|
| from copy import deepcopy |
|
|
| from ..base import BaseEstimator |
| from ..exceptions import NotFittedError |
| from ..utils import get_tags |
| from ..utils.metaestimators import available_if |
| from ..utils.validation import check_is_fitted |
|
|
|
|
def _estimator_has(attr):
    """Build a predicate telling whether the wrapped ``estimator`` has `attr`.

    Used together with `available_if`: the returned predicate receives the
    wrapper instance and inspects its ``estimator`` attribute.

    Parameters
    ----------
    attr : str
        Name of the attribute to look up on ``wrapper.estimator``.

    Returns
    -------
    callable
        A one-argument function that returns ``True`` when the attribute exists.
        When it does not exist, the ``AttributeError`` raised by the lookup
        propagates out of the predicate unchanged.
    """

    def _has_attr(wrapper):
        # Deliberately not caught: a missing attribute surfaces as the inner
        # estimator's own AttributeError.
        getattr(wrapper.estimator, attr)
        return True

    return _has_attr
|
|
|
|
class FrozenEstimator(BaseEstimator):
    """Estimator that wraps a fitted estimator to prevent re-fitting.

    This meta-estimator takes an estimator and freezes it, in the sense that calling
    `fit` on it has no effect. `fit_predict` and `fit_transform` are also disabled.
    All other methods are delegated to the original estimator and original estimator's
    attributes are accessible as well.

    This is particularly useful when you have a fitted or a pre-trained model as a
    transformer in a pipeline, and you'd like `pipeline.fit` to have no effect on this
    step.

    Parameters
    ----------
    estimator : estimator
        The estimator which is to be kept frozen.

    See Also
    --------
    None: No similar entry in the scikit-learn documentation.

    Examples
    --------
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.frozen import FrozenEstimator
    >>> from sklearn.linear_model import LogisticRegression
    >>> X, y = make_classification(random_state=0)
    >>> clf = LogisticRegression(random_state=0).fit(X, y)
    >>> frozen_clf = FrozenEstimator(clf)
    >>> frozen_clf.fit(X, y)  # No-op
    FrozenEstimator(estimator=LogisticRegression(random_state=0))
    >>> frozen_clf.predict(X)  # Predictions from `clf.predict`
    array(...)
    """

    def __init__(self, estimator):
        self.estimator = estimator

    @available_if(_estimator_has("__getitem__"))
    def __getitem__(self, *args, **kwargs):
        """__getitem__ is defined in :class:`~sklearn.pipeline.Pipeline` and \
            :class:`~sklearn.compose.ColumnTransformer`.
        """
        return self.estimator.__getitem__(*args, **kwargs)

    def __getattr__(self, name):
        # Python only calls `__getattr__` when normal lookup fails. Anything not
        # defined on the wrapper itself is therefore forwarded to the inner
        # estimator, except the fitting shortcuts that would re-fit it.
        if name == "estimator":
            # `estimator` is missing from the instance dict (e.g. an instance
            # created via `__new__` without `__init__`). Forwarding would evaluate
            # `self.estimator` again and recurse until RecursionError.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'estimator'"
            )
        if name in ("fit_predict", "fit_transform"):
            raise AttributeError(f"{name} is not available for frozen estimators.")
        return getattr(self.estimator, name)

    def __sklearn_clone__(self):
        # Cloning hands back this same (already fitted) wrapper instead of
        # building a fresh copy.
        return self

    def __sklearn_is_fitted__(self):
        # The wrapper is fitted exactly when the inner estimator is.
        try:
            check_is_fitted(self.estimator)
            return True
        except NotFittedError:
            return False

    def fit(self, X, y=None, *args, **kwargs):
        """No-op.

        As a frozen estimator, calling `fit` has no effect. The wrapped estimator
        is only checked for being fitted.

        Parameters
        ----------
        X : object
            Ignored.

        y : object, default=None
            Ignored. Optional so that `fit(X)` works for unsupervised estimators.

        *args : tuple
            Additional positional arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        **kwargs : dict
            Additional keyword arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        Returns
        -------
        self : object
            Returns the instance itself.

        Raises
        ------
        NotFittedError
            Raised by `check_is_fitted` if the wrapped estimator is not fitted.
        """
        check_is_fitted(self.estimator)
        return self

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        The only valid key here is `estimator`. You cannot set the parameters of the
        inner estimator. Passing `estimator=None` leaves the current estimator
        unchanged.

        Parameters
        ----------
        **kwargs : dict
            Estimator parameters.

        Returns
        -------
        self : FrozenEstimator
            This estimator.

        Raises
        ------
        ValueError
            If any key other than `estimator` is given. In that case the wrapper
            is left unmodified.
        """
        # Validate before mutating, so that a rejected call has no side effect.
        if any(key != "estimator" for key in kwargs):
            raise ValueError(
                "You cannot set parameters of the inner estimator in a frozen "
                "estimator since calling `fit` has no effect. You can use "
                "`frozenestimator.estimator.set_params` to set parameters of the inner "
                "estimator."
            )
        estimator = kwargs.get("estimator")
        if estimator is not None:
            self.estimator = estimator
        return self

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Returns a `{"estimator": estimator}` dict. The parameters of the inner
        estimator are not included.

        Parameters
        ----------
        deep : bool, default=True
            Ignored.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        return {"estimator": self.estimator}

    def __sklearn_tags__(self):
        # Start from the inner estimator's tags. Copy them so that setting
        # `_skip_test` below does not mutate the object returned by `get_tags`.
        tags = deepcopy(get_tags(self.estimator))
        # NOTE(review): presumably tells scikit-learn's common estimator checks
        # to skip this wrapper — confirm against the tags definition.
        tags._skip_test = True
        return tags
|
|